06. Transfer Learning with TensorFlow Part 3: Scaling up (Food Vision mini)¶
We've seen how powerful transfer learning is in the part 1 and 2 notebooks. Those were all small modelling experiments, so it's time to step things up.
It's common practice in ML and deep learning to get a model working on a small subset of data before scaling up to the larger/full dataset.
We'll scale up from 10 classes of Food101 to all 101 classes.
Our goal is to beat the original Food101 paper's results using only 10% of the training data.

ML practitioners are serial experimenters. Start small, get a model working, see how it goes, and then gradually scale up to your end goal.
What we're going to cover¶
We're gonna go through the following:
- Downloading and preparing 10% of Food101 data
- Training a feature extraction transfer learning model on 10% of the Food101 training data
- Fine-tuning our feature extraction model
- Saving and loading our trained model
- Evaluating the performance of our Food Vision model
- Finding the model's worst-performing predictions
- Making predictions with our Food Vision model on custom images of food
# are we using gpu?
!nvidia-smi
Mon Oct 27 19:42:18 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94 Driver Version: 560.94 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Driver-Model | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce GTX 1060 6GB WDDM | 00000000:0A:00.0 On | N/A |
| 0% 54C P0 30W / 120W | 1661MiB / 6144MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1764 C+G ...ejd91yc\AdobeNotificationClient.exe N/A |
| 0 N/A N/A 1844 C+G ...GeForce Experience\NVIDIA Share.exe N/A |
| 0 N/A N/A 3428 C ...elchupacabra\App\Bandicam\bdcam.exe N/A |
| 0 N/A N/A 7528 C+G ...\Adobe Photoshop 2021\Photoshop.exe N/A |
| 0 N/A N/A 10652 C+G ...GeForce Experience\NVIDIA Share.exe N/A |
| 0 N/A N/A 11660 C+G ...oogle\Chrome\Application\chrome.exe N/A |
| 0 N/A N/A 14488 C+G ... Files\Elgato\WaveLink\WaveLink.exe N/A |
| 0 N/A N/A 17812 C+G ...remium\win64\bin\HarmonyPremium.exe N/A |
| 0 N/A N/A 18036 C+G ...cal\Microsoft\OneDrive\OneDrive.exe N/A |
| 0 N/A N/A 18352 C+G X:\Mozilla_Thunderbird\thunderbird.exe N/A |
| 0 N/A N/A 19188 C+G ...soft Office\root\Office16\EXCEL.EXE N/A |
| 0 N/A N/A 19676 C+G ...2.0_x64__cv1g1gvanyjgm\WhatsApp.exe N/A |
| 0 N/A N/A 20236 C+G ...dobe\Adobe Animate 2021\Animate.exe N/A |
| 0 N/A N/A 20584 C+G ...cal\Microsoft\OneDrive\OneDrive.exe N/A |
| 0 N/A N/A 21080 C+G ...\cef\cef.win7x64\steamwebhelper.exe N/A |
| 0 N/A N/A 25892 C+G ...oogle\Chrome\Application\chrome.exe N/A |
| 0 N/A N/A 36744 C+G ...t.LockApp_cw5n1h2txyewy\LockApp.exe N/A |
| 0 N/A N/A 37408 C+G ...al\Discord\app-1.0.9205\Discord.exe N/A |
| 0 N/A N/A 38768 C+G ...2txyewy\StartMenuExperienceHost.exe N/A |
| 0 N/A N/A 43584 C+G ....Search_cw5n1h2txyewy\SearchApp.exe N/A |
| 0 N/A N/A 47300 C+G ...ekyb3d8bbwe\PhoneExperienceHost.exe N/A |
| 0 N/A N/A 50096 C+G ...-ins\Spaces\Adobe Spaces Helper.exe N/A |
| 0 N/A N/A 52420 C+G ...CBS_cw5n1h2txyewy\TextInputHost.exe N/A |
| 0 N/A N/A 52568 C+G ...\DAUM\PotPlayer\PotPlayerMini64.exe N/A |
| 0 N/A N/A 53760 C+G ...5n1h2txyewy\ShellExperienceHost.exe N/A |
| 0 N/A N/A 56388 C+G ...05.0_x64__8wekyb3d8bbwe\Cortana.exe N/A |
| 0 N/A N/A 64072 C+G ...crosoft\Edge\Application\msedge.exe N/A |
| 0 N/A N/A 64204 C+G ...1.0_x64__8wekyb3d8bbwe\Video.UI.exe N/A |
| 0 N/A N/A 65052 C+G ...on\141.0.3537.85\msedgewebview2.exe N/A |
| 0 N/A N/A 66376 C+G ...remium\win64\bin\HarmonyPremium.exe N/A |
| 0 N/A N/A 66948 C+G C:\Windows\explorer.exe N/A |
| 0 N/A N/A 69428 C+G ....Search_cw5n1h2txyewy\SearchApp.exe N/A |
| 0 N/A N/A 71684 C+G X:\Microsoft VS Code\Code.exe N/A |
+-----------------------------------------------------------------------------------------+
import datetime
print(f'Notebook last run (end-to-end): {datetime.datetime.now()}')
Notebook last run (end-to-end): 2025-11-03 23:04:30.566595
Creating helper functions¶
We've got a script containing the useful functions from previous notebooks. Rather than rewriting them all, it's easier to download the .py file and import what we need.
# get helper_functions.py script from course Github
!curl -O https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
# import helper functions we're going to use
import sys, os
sys.path.append(os.getcwd())
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, walk_through_dir, compare_historys
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0
100 10246 100 10246 0 0 23537 0 --:--:-- --:--:-- --:--:-- 23717
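If curl isn't available on your system, the same download can be done in pure Python with the standard library's urllib.request (the helper function below is our own wrapper, not part of the course code):

```python
import urllib.request

def download_file(url, filename):
    """Download url to filename and return the local file path."""
    filepath, _ = urllib.request.urlretrieve(url, filename)
    return filepath
```

For example, `download_file("https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py", "helper_functions.py")` does the same job as the curl command above.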
101 Food Classes: Working with less data¶
So far, our transfer learning experiments have worked quite well on 10 classes of food data. Now it's time to make the jump to the full 101 classes.
The original Food101 dataset has 1,000 images per class (750 for training, 250 for testing), totalling 101,000 unique images.
We could use the full dataset, but in the spirit of fast experimentation, we'll use only 10% of the training data and see how it goes.
This means only 75 training images per class across all 101 classes, while keeping the original 250 test images per class.
Downloading and preprocessing data¶
We'll download a subset of the Food101 dataset, which comes as a zip file, and use the unzip_data() helper function to unzip it.
# download data from Google Storage
!curl -O https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip
unzip_data('101_food_classes_10_percent.zip')
train_dir = '101_food_classes_10_percent/train/'
test_dir = '101_food_classes_10_percent/test/'
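For reference, unzip_data() from helper_functions.py is roughly a thin wrapper around Python's built-in zipfile module; a minimal sketch (assuming it simply extracts into the current working directory) looks like this:

```python
import zipfile

def unzip_data(filename):
    """Unzips filename into the current working directory."""
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall()
```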
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 1550M 100 1550M 0 0 10.9M 0 0:02:21 0:02:21 --:--:-- 11.4M
# How many images/classes are there?
walk_through_dir('101_food_classes_10_percent')
There are 2 directories and 0 images in '101_food_classes_10_percent'. There are 101 directories and 0 images in '101_food_classes_10_percent\test'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\foie_gras'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\club_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\cheese_plate'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\cup_cakes'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\garlic_bread'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\gnocchi'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\ice_cream'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\samosa'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\donuts'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\tuna_tartare'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\filet_mignon'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\seaweed_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\french_toast'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\chicken_curry'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\shrimp_and_grits'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\steak'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\cheesecake'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\red_velvet_cake'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\waffles'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\churros'. 
There are 0 directories and 250 images in '101_food_classes_10_percent\test\gyoza'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\lobster_roll_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\huevos_rancheros'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\breakfast_burrito'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\grilled_cheese_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\spaghetti_bolognese'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\falafel'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\poutine'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\greek_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\beef_tartare'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\fried_calamari'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\guacamole'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\ravioli'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\lobster_bisque'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\beet_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\risotto'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\crab_cakes'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\strawberry_shortcake'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\edamame'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\ceviche'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\hot_and_sour_soup'. 
There are 0 directories and 250 images in '101_food_classes_10_percent\test\spring_rolls'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\sashimi'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\paella'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\clam_chowder'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\miso_soup'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\escargots'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\hot_dog'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\pulled_pork_sandwich'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\bruschetta'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\panna_cotta'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\fish_and_chips'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\pad_thai'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\tiramisu'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\takoyaki'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\macarons'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\apple_pie'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\cannoli'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\scallops'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\frozen_yogurt'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\chicken_quesadilla'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\mussels'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\beef_carpaccio'. 
There are 0 directories and 250 images in '101_food_classes_10_percent\test\eggs_benedict'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\spaghetti_carbonara'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\omelette'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\sushi'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\chocolate_mousse'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\beignets'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\bibimbap'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\hummus'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\pork_chop'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\chicken_wings'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\grilled_salmon'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\chocolate_cake'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\tacos'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\hamburger'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\baby_back_ribs'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\pancakes'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\prime_rib'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\pizza'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\nachos'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\macaroni_and_cheese'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\bread_pudding'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\ramen'. 
There are 0 directories and 250 images in '101_food_classes_10_percent\test\croque_madame'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\lasagna'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\peking_duck'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\deviled_eggs'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\french_fries'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\dumplings'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\fried_rice'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\french_onion_soup'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\pho'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\caprese_salad'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\oysters'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\baklava'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\creme_brulee'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\carrot_cake'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\onion_rings'. There are 0 directories and 250 images in '101_food_classes_10_percent\test\caesar_salad'. There are 101 directories and 0 images in '101_food_classes_10_percent\train'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\foie_gras'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\club_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\cheese_plate'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\cup_cakes'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\garlic_bread'. 
There are 0 directories and 75 images in '101_food_classes_10_percent\train\gnocchi'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\ice_cream'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\samosa'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\donuts'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\tuna_tartare'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\filet_mignon'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\seaweed_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\french_toast'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\chicken_curry'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\shrimp_and_grits'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\steak'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\cheesecake'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\red_velvet_cake'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\waffles'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\churros'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\gyoza'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\lobster_roll_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\huevos_rancheros'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\breakfast_burrito'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\grilled_cheese_sandwich'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\spaghetti_bolognese'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\falafel'. 
There are 0 directories and 75 images in '101_food_classes_10_percent\train\poutine'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\greek_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\beef_tartare'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\fried_calamari'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\guacamole'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\ravioli'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\lobster_bisque'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\beet_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\risotto'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\crab_cakes'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\strawberry_shortcake'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\edamame'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\ceviche'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\hot_and_sour_soup'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\spring_rolls'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\sashimi'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\paella'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\clam_chowder'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\miso_soup'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\escargots'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\hot_dog'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\pulled_pork_sandwich'. 
There are 0 directories and 75 images in '101_food_classes_10_percent\train\bruschetta'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\panna_cotta'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\fish_and_chips'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\pad_thai'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\tiramisu'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\takoyaki'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\macarons'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\apple_pie'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\cannoli'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\scallops'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\frozen_yogurt'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\chicken_quesadilla'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\mussels'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\beef_carpaccio'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\eggs_benedict'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\spaghetti_carbonara'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\omelette'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\sushi'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\chocolate_mousse'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\beignets'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\bibimbap'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\hummus'. 
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pork_chop'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\chicken_wings'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\grilled_salmon'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\chocolate_cake'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\tacos'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\hamburger'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\baby_back_ribs'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\pancakes'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\prime_rib'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\pizza'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\nachos'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\macaroni_and_cheese'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\bread_pudding'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\ramen'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\croque_madame'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\lasagna'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\peking_duck'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\deviled_eggs'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\french_fries'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\dumplings'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\fried_rice'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\french_onion_soup'. 
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pho'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\caprese_salad'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\oysters'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\baklava'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\creme_brulee'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\carrot_cake'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\onion_rings'. There are 0 directories and 75 images in '101_food_classes_10_percent\train\caesar_salad'.
As before, our data is structured as follows:
101_food_classes_10_percent <- top level folder
└───train <- training images
│ └───pizza
│ │ │ 1008104.jpg
│ │ │ 1638227.jpg
│ │ │ ...
│ └───steak
│ │ 1000205.jpg
│ │ 1647351.jpg
│ │ ...
│
└───test <- testing images
│ └───pizza
│ │ │ 1001116.jpg
│ │ │ 1507019.jpg
│ │ │ ...
│ └───steak
│ │ 100274.jpg
│ │ 1653815.jpg
│ │ ...
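The per-class counts printed above come from walking the directory tree. A minimal sketch of such a helper (the course's `walk_through_dir` in `helper_functions.py` works along these lines; this version is an illustrative stand-in, not the exact course code):

```python
import os

def walk_through_dir(dir_path):
    """Print how many subdirectories and images live in each folder under dir_path."""
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
```

Calling `walk_through_dir('101_food_classes_10_percent')` produces output like the listing above.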
Let's use the image_dataset_from_directory() function to turn our images and labels into a tf.data.Dataset, a TensorFlow data type we can pass straight to our model.
For the test dataset, we're going to set shuffle=False, so we can perform repeatable evaluation and visualization on it later.
# setup data inputs
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_all_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
label_mode='categorical',
image_size=IMG_SIZE)
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
label_mode='categorical',
image_size=IMG_SIZE,
shuffle=False) # don't shuffle to keep experiments repeatable
Found 7575 files belonging to 101 classes. Found 25250 files belonging to 101 classes.
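The printout above confirms the datasets were created, but it's worth sanity-checking what a batch looks like. The sketch below builds a small synthetic stand-in (random tensors) rather than re-reading the real image directories, so it runs on its own:

```python
import tensorflow as tf

# Tiny synthetic stand-in: 8 random "images", one-hot labels for 101 classes
num_classes = 101
images = tf.random.uniform((8, 224, 224, 3))
labels = tf.one_hot(tf.random.uniform((8,), maxval=num_classes, dtype=tf.int32), depth=num_classes)
demo_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Each element is an (images, labels) tuple, just like the datasets
# image_dataset_from_directory() returns (its default batch size is 32)
for batch_images, batch_labels in demo_ds.take(1):
    print(batch_images.shape)  # (4, 224, 224, 3)
    print(batch_labels.shape)  # (4, 101) - one-hot because label_mode='categorical'
```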
Train a big dog model with transfer learning on 10% of 101 food classes¶
To keep experiments swift, we're going to start by using feature extraction transfer learning with a pre-trained model for a few epochs, then fine-tune it for a few more epochs.
Our goal is to see if we can beat the baseline from the original Food101 paper (50.76% accuracy), while using only 10% of the training data.
To do so, we'll use:
- A ModelCheckpoint callback to save progress during training, so we can experiment with further training later without having to train from scratch every time
- Data augmentation built right into the model
- A headless (no top layers) EfficientNetB0 architecture from tf.keras.applications as our base model
- A Dense layer with 101 output neurons (same as the number of food classes) and softmax activation as the output layer
- Categorical crossentropy as the loss function, since we're dealing with more than two classes
- The Adam optimizer with its default settings
- Fitting for 5 full passes over the training data while evaluating on 15% of the test data
It seems like a lot, but we covered all of these before in part 2 (notebook 05).
Let's start by creating the ModelCheckpoint callback.
Since we want the model to perform well on unseen data, we'll have it monitor the validation accuracy metric and save only the weights that score best on it.
# create checkpoint callback to save model for later use
checkpoint_path = '101_classes_10_percent_data_model_checkpoint.weights.h5'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, # save only the model weights
monitor='val_accuracy', # monitor validation accuracy to decide which weights get saved
save_best_only=True) # only keep the best resulting weights, and discard the rest
The checkpoint is now ready. Let's create a small data augmentation model with the Sequential API. Because our training set is deliberately small, augmentation will help prevent overfitting.
# import the required modules for model creation
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
data_augmentation = Sequential([
layers.RandomFlip('horizontal'),
layers.RandomRotation(0.2),
layers.RandomZoom(0.2),
layers.RandomHeight(0.2),
layers.RandomWidth(0.2),
# layers.Rescaling(1./255) # keep for models like ResNet50V2, remove for EfficientNetB0 as it rescales inputs itself
], name='data_augmentation')
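To check the augmentation pipeline behaves, we can push a dummy image through it. The sketch below rebuilds a trimmed-down copy of the layers so the cell runs on its own (RandomHeight/RandomWidth are omitted here so the output shape stays fixed):

```python
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Trimmed-down copy of the augmentation model above, for a quick smoke test
demo_augmentation = Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
], name='demo_augmentation')

dummy_image = tf.random.uniform((1, 224, 224, 3))  # batch of one random "image"
augmented = demo_augmentation(dummy_image, training=True)  # training=True forces augmentation on
print(augmented.shape)  # (1, 224, 224, 3) - these layers preserve spatial dims
```

Remember these layers are inactive at inference time unless called with `training=True`.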
Because data_augmentation is a Sequential model, we can insert it as a layer inside a Functional API model. That way, if we train the model further later on, augmentation is already baked in.
Now it's time to put it all together and build our feature extraction transfer learning model, using tf.keras.applications.efficientnet.EfficientNetB0 as the base model.
We'll import the base model with include_top=False so we can add our own top layers: a GlobalAveragePooling2D() layer, which condenses the base model's outputs into a 1D feature vector (a usable shape for the output layer), followed by a Dense output layer.
# setup base model and freeze its layers (this will extract features)
base_model = tf.keras.applications.EfficientNetB0(include_top=False, weights='imagenet', input_shape=(224,224,3))
base_model.trainable = False
# setup model architecture with trainable top layers
inputs = layers.Input(shape=(224,224,3), name='input_layer') # shape of input image
x = data_augmentation(inputs) # augment images (will only happen during training)
x = base_model(x, training=False) # put the base model in inference mode (training=False) so we can use it to extract features, and not update weights
x = layers.GlobalAveragePooling2D(name='global_average_pooling')(x) # pool the outputs of the base model
outputs = layers.Dense(len(train_data_all_10_percent.class_names), activation='softmax', name='output_layer')(x) # same number of outputs as classes
model = tf.keras.Model(inputs, outputs)
Figure: an illustration of our model's layers, from input to output.
Let's inspect the model:
# get a summary of the model
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 224, 224, 3)] 0
data_augmentation (Sequent (None, None, None, 3) 0
ial)
efficientnetb0 (Functional (None, 7, 7, 1280) 4049571
)
global_average_pooling (Gl (None, 1280) 0
obalAveragePooling2D)
output_layer (Dense) (None, 101) 129381
=================================================================
Total params: 4178952 (15.94 MB)
Trainable params: 129381 (505.39 KB)
Non-trainable params: 4049571 (15.45 MB)
_________________________________________________________________
Nice, our functional model shows 5 top-level layers, but some of those (like the base model) contain many layers of their own.
Notice the difference between trainable and non-trainable parameters: only the output_layer is trainable, while the efficientnetb0 base model is frozen. That's feature extraction: we keep the base model's learned patterns fixed while letting the output layer adjust and tune itself to our custom data.
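As a quick check, the trainable-parameter count can be worked out by hand: the Dense output layer maps the 1280 pooled features to 101 classes, plus one bias per class.

```python
# Dense layer parameters = inputs * outputs (weights) + outputs (biases)
trainable_params = 1280 * 101 + 101
print(trainable_params)  # 129381 - matches "Trainable params" in the summary
```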
Time to compile and fit.
# compile
model.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(), # using Adam's default learning rate
metrics=['accuracy'])
# fit
history_all_classes_10_percent = model.fit(train_data_all_10_percent,
epochs=5, # fitting 5 epochs to keep experiments short
validation_data=test_data,
validation_steps=int(0.15 * len(test_data)), # evaluate on 15% of test data, again to keep experiments quick
callbacks=[checkpoint_callback]) # callback to save best model weights on file
Epoch 1/5 237/237 [==============================] - 328s 1s/step - loss: 3.3825 - accuracy: 0.2684 - val_loss: 2.4930 - val_accuracy: 0.4396 Epoch 2/5 237/237 [==============================] - 297s 1s/step - loss: 2.2078 - accuracy: 0.4939 - val_loss: 2.0097 - val_accuracy: 0.5254 Epoch 3/5 237/237 [==============================] - 292s 1s/step - loss: 1.8206 - accuracy: 0.5706 - val_loss: 1.8636 - val_accuracy: 0.5344 Epoch 4/5 237/237 [==============================] - 293s 1s/step - loss: 1.6076 - accuracy: 0.6136 - val_loss: 1.8080 - val_accuracy: 0.5384 Epoch 5/5 237/237 [==============================] - 291s 1s/step - loss: 1.4556 - accuracy: 0.6437 - val_loss: 1.7518 - val_accuracy: 0.5445
It seems the model achieved impressive results, but it was only evaluated on 15% of the test data. Let's evaluate it on the full test dataset.
# evaluate model
results_feature_extraction_model = model.evaluate(test_data)
results_feature_extraction_model
790/790 [==============================] - 635s 803ms/step - loss: 1.5837 - accuracy: 0.5846
[1.5837445259094238, 0.5845940709114075]
Well, it looks as if we've just beaten the Food101 paper's baseline with 10% of the data! That's the strength of deep learning, and more precisely transfer learning: leveraging what a model has learned on one dataset to solve another.
How do the loss curves look?
plot_loss_curves(history_all_classes_10_percent)
Question: What should we expect the curves to suggest? Ideally, we want both curves to track each other closely. If the two diverge, the model may be overfitting or underfitting.
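plot_loss_curves() comes from the course's helper_functions.py, which isn't shown in this notebook. A minimal sketch of what it does, assuming a Keras History object as input:

```python
import matplotlib.pyplot as plt

def plot_loss_curves(history):
    """Plot training/validation loss and accuracy as two separate figures."""
    epochs = range(len(history.history['loss']))

    # Loss curves
    plt.figure()
    plt.plot(epochs, history.history['loss'], label='training_loss')
    plt.plot(epochs, history.history['val_loss'], label='val_loss')
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.legend()

    # Accuracy curves
    plt.figure()
    plt.plot(epochs, history.history['accuracy'], label='training_accuracy')
    plt.plot(epochs, history.history['val_accuracy'], label='val_accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.legend()
```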
Fine tuning¶
Our feature extraction transfer learning model is performing well. Why not try fine-tuning a few layers in the base model to see if we can gain further improvements?
Thanks to the ModelCheckpoint callback, we have the saved weights of our current best model, so if fine-tuning doesn't offer any benefit, we can revert to them.
To fine-tune, we need to set the base model's trainable attribute to True.
Because our training dataset is (deliberately) small, we'll then refreeze every layer except the last 5, leaving only those trainable.
# unfreeze the base model
base_model.trainable = True
# refreeze layers except the last 5 layers
for layer in base_model.layers[:-5]:
layer.trainable = False
Now that we've changed the Functional API model, we need to recompile it for the changes to take effect.
Because we're fine-tuning, we'll lower the learning rate by 10x so that weight updates stay small and don't heavily disturb the pre-trained weights, which are already well calibrated for problems like ours.

Note: When fine-tuning unfrozen layers of a pre-trained model, it's common practice to lower the learning rate by 10x.
# recompile model with lower learning rate
model.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(1e-4), # lower learning rate by 10 fold from the default
metrics=['accuracy'])
The model is recompiled. Let's check which layers are trainable:
# what layers in the model are trainable?
for layer in model.layers:
print(layer.name, layer.trainable)
input_layer True data_augmentation True efficientnetb0 True global_average_pooling True output_layer True
# check which layers are trainable
for layer_number, layer in enumerate(base_model.layers):
print(layer_number, layer.name, layer.trainable)
0 input_3 False 1 rescaling_2 False 2 normalization_2 False 3 rescaling_3 False 4 stem_conv_pad False 5 stem_conv False 6 stem_bn False 7 stem_activation False 8 block1a_dwconv False 9 block1a_bn False 10 block1a_activation False 11 block1a_se_squeeze False 12 block1a_se_reshape False 13 block1a_se_reduce False 14 block1a_se_expand False 15 block1a_se_excite False 16 block1a_project_conv False 17 block1a_project_bn False 18 block2a_expand_conv False 19 block2a_expand_bn False 20 block2a_expand_activation False 21 block2a_dwconv_pad False 22 block2a_dwconv False 23 block2a_bn False 24 block2a_activation False 25 block2a_se_squeeze False 26 block2a_se_reshape False 27 block2a_se_reduce False 28 block2a_se_expand False 29 block2a_se_excite False 30 block2a_project_conv False 31 block2a_project_bn False 32 block2b_expand_conv False 33 block2b_expand_bn False 34 block2b_expand_activation False 35 block2b_dwconv False 36 block2b_bn False 37 block2b_activation False 38 block2b_se_squeeze False 39 block2b_se_reshape False 40 block2b_se_reduce False 41 block2b_se_expand False 42 block2b_se_excite False 43 block2b_project_conv False 44 block2b_project_bn False 45 block2b_drop False 46 block2b_add False 47 block3a_expand_conv False 48 block3a_expand_bn False 49 block3a_expand_activation False 50 block3a_dwconv_pad False 51 block3a_dwconv False 52 block3a_bn False 53 block3a_activation False 54 block3a_se_squeeze False 55 block3a_se_reshape False 56 block3a_se_reduce False 57 block3a_se_expand False 58 block3a_se_excite False 59 block3a_project_conv False 60 block3a_project_bn False 61 block3b_expand_conv False 62 block3b_expand_bn False 63 block3b_expand_activation False 64 block3b_dwconv False 65 block3b_bn False 66 block3b_activation False 67 block3b_se_squeeze False 68 block3b_se_reshape False 69 block3b_se_reduce False 70 block3b_se_expand False 71 block3b_se_excite False 72 block3b_project_conv False 73 block3b_project_bn False 74 block3b_drop False 75 block3b_add False 
76 block4a_expand_conv False 77 block4a_expand_bn False 78 block4a_expand_activation False 79 block4a_dwconv_pad False 80 block4a_dwconv False 81 block4a_bn False 82 block4a_activation False 83 block4a_se_squeeze False 84 block4a_se_reshape False 85 block4a_se_reduce False 86 block4a_se_expand False 87 block4a_se_excite False 88 block4a_project_conv False 89 block4a_project_bn False 90 block4b_expand_conv False 91 block4b_expand_bn False 92 block4b_expand_activation False 93 block4b_dwconv False 94 block4b_bn False 95 block4b_activation False 96 block4b_se_squeeze False 97 block4b_se_reshape False 98 block4b_se_reduce False 99 block4b_se_expand False 100 block4b_se_excite False 101 block4b_project_conv False 102 block4b_project_bn False 103 block4b_drop False 104 block4b_add False 105 block4c_expand_conv False 106 block4c_expand_bn False 107 block4c_expand_activation False 108 block4c_dwconv False 109 block4c_bn False 110 block4c_activation False 111 block4c_se_squeeze False 112 block4c_se_reshape False 113 block4c_se_reduce False 114 block4c_se_expand False 115 block4c_se_excite False 116 block4c_project_conv False 117 block4c_project_bn False 118 block4c_drop False 119 block4c_add False 120 block5a_expand_conv False 121 block5a_expand_bn False 122 block5a_expand_activation False 123 block5a_dwconv False 124 block5a_bn False 125 block5a_activation False 126 block5a_se_squeeze False 127 block5a_se_reshape False 128 block5a_se_reduce False 129 block5a_se_expand False 130 block5a_se_excite False 131 block5a_project_conv False 132 block5a_project_bn False 133 block5b_expand_conv False 134 block5b_expand_bn False 135 block5b_expand_activation False 136 block5b_dwconv False 137 block5b_bn False 138 block5b_activation False 139 block5b_se_squeeze False 140 block5b_se_reshape False 141 block5b_se_reduce False 142 block5b_se_expand False 143 block5b_se_excite False 144 block5b_project_conv False 145 block5b_project_bn False 146 block5b_drop False 147 block5b_add False 148 
block5c_expand_conv False 149 block5c_expand_bn False 150 block5c_expand_activation False 151 block5c_dwconv False 152 block5c_bn False 153 block5c_activation False 154 block5c_se_squeeze False 155 block5c_se_reshape False 156 block5c_se_reduce False 157 block5c_se_expand False 158 block5c_se_excite False 159 block5c_project_conv False 160 block5c_project_bn False 161 block5c_drop False 162 block5c_add False 163 block6a_expand_conv False 164 block6a_expand_bn False 165 block6a_expand_activation False 166 block6a_dwconv_pad False 167 block6a_dwconv False 168 block6a_bn False 169 block6a_activation False 170 block6a_se_squeeze False 171 block6a_se_reshape False 172 block6a_se_reduce False 173 block6a_se_expand False 174 block6a_se_excite False 175 block6a_project_conv False 176 block6a_project_bn False 177 block6b_expand_conv False 178 block6b_expand_bn False 179 block6b_expand_activation False 180 block6b_dwconv False 181 block6b_bn False 182 block6b_activation False 183 block6b_se_squeeze False 184 block6b_se_reshape False 185 block6b_se_reduce False 186 block6b_se_expand False 187 block6b_se_excite False 188 block6b_project_conv False 189 block6b_project_bn False 190 block6b_drop False 191 block6b_add False 192 block6c_expand_conv False 193 block6c_expand_bn False 194 block6c_expand_activation False 195 block6c_dwconv False 196 block6c_bn False 197 block6c_activation False 198 block6c_se_squeeze False 199 block6c_se_reshape False 200 block6c_se_reduce False 201 block6c_se_expand False 202 block6c_se_excite False 203 block6c_project_conv False 204 block6c_project_bn False 205 block6c_drop False 206 block6c_add False 207 block6d_expand_conv False 208 block6d_expand_bn False 209 block6d_expand_activation False 210 block6d_dwconv False 211 block6d_bn False 212 block6d_activation False 213 block6d_se_squeeze False 214 block6d_se_reshape False 215 block6d_se_reduce False 216 block6d_se_expand False 217 block6d_se_excite False 218 block6d_project_conv False 219 
block6d_project_bn False 220 block6d_drop False 221 block6d_add False 222 block7a_expand_conv False 223 block7a_expand_bn False 224 block7a_expand_activation False 225 block7a_dwconv False 226 block7a_bn False 227 block7a_activation False 228 block7a_se_squeeze False 229 block7a_se_reshape False 230 block7a_se_reduce False 231 block7a_se_expand False 232 block7a_se_excite False 233 block7a_project_conv True 234 block7a_project_bn True 235 top_conv True 236 top_bn True 237 top_activation True
Nice, time to fine-tune the model.
Another 5 epochs should be enough to see whether fine-tuning benefits the model, though more epochs wouldn't hurt either.
We'll resume training where feature extraction left off using the initial_epoch parameter of the fit() function.
# fine tune 5 more epochs
fine_tune_epochs = 10 # model has already done 5 epochs; this is the total number of epochs we're after (5 initial + another 5)
history_all_classes_10_percent_fine_tune = model.fit(train_data_all_10_percent,
epochs=fine_tune_epochs,
validation_data=test_data,
validation_steps=int(0.15*len(test_data)),
initial_epoch=history_all_classes_10_percent.epoch[-1]) # start from previous last epoch
Epoch 5/10 237/237 [==============================] - 355s 1s/step - loss: 1.2106 - accuracy: 0.6825 - val_loss: 1.6916 - val_accuracy: 0.5559 Epoch 6/10 237/237 [==============================] - 333s 1s/step - loss: 1.0890 - accuracy: 0.7067 - val_loss: 1.7174 - val_accuracy: 0.5493 Epoch 7/10 237/237 [==============================] - 327s 1s/step - loss: 1.0212 - accuracy: 0.7298 - val_loss: 1.6693 - val_accuracy: 0.5612 Epoch 8/10 237/237 [==============================] - 331s 1s/step - loss: 0.9446 - accuracy: 0.7440 - val_loss: 1.7185 - val_accuracy: 0.5519 Epoch 9/10 237/237 [==============================] - 332s 1s/step - loss: 0.8885 - accuracy: 0.7625 - val_loss: 1.7214 - val_accuracy: 0.5551 Epoch 10/10 237/237 [==============================] - 331s 1s/step - loss: 0.8269 - accuracy: 0.7754 - val_loss: 1.7270 - val_accuracy: 0.5506
Once again, we're only evaluating on a small portion of the test data. Let's find out how well the model does on the full test dataset.
# evaluate fine-tuned model on the whole test dataset
results_all_classes_10_percent_fine_tune = model.evaluate(test_data)
results_all_classes_10_percent_fine_tune
790/790 [==============================] - 684s 865ms/step - loss: 1.5024 - accuracy: 0.6017
[1.5023630857467651, 0.60166335105896]
Unfortunately, it seems there's only a minimal improvement.
We may get a better idea by using the compare_historys() function to see what the training curves look like.
compare_historys(original_history=history_all_classes_10_percent,
new_history=history_all_classes_10_percent_fine_tune,
initial_epochs=5)
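Like plot_loss_curves(), compare_historys() lives in the course's helper_functions.py. A minimal sketch, assuming two Keras History objects whose metrics get stitched together at the epoch where fine-tuning began:

```python
import matplotlib.pyplot as plt

def compare_historys(original_history, new_history, initial_epochs=5):
    """Concatenate two training histories and plot them on shared axes."""
    # Stitch the metrics from both training runs together
    acc = original_history.history['accuracy'] + new_history.history['accuracy']
    loss = original_history.history['loss'] + new_history.history['loss']
    val_acc = original_history.history['val_accuracy'] + new_history.history['val_accuracy']
    val_loss = original_history.history['val_loss'] + new_history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='training_accuracy')
    plt.plot(val_acc, label='val_accuracy')
    plt.plot([initial_epochs - 1, initial_epochs - 1], plt.ylim(), label='start fine tuning')
    plt.legend(loc='lower right')
    plt.title('Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='training_loss')
    plt.plot(val_loss, label='val_loss')
    plt.plot([initial_epochs - 1, initial_epochs - 1], plt.ylim(), label='start fine tuning')
    plt.legend(loc='upper right')
    plt.title('Loss')
    plt.xlabel('epoch')
```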
With fine-tuning, accuracy on the training data improved significantly, and its trend suggests it would keep climbing. The validation accuracy, however, barely kept up: classic signs of overfitting.
This is fine though, as fine-tuning on a small dataset often leads to overfitting the training data, especially when the pre-trained model has already been trained on data similar to our custom problem.
In our case, the pre-trained EfficientNetB0 model was trained on ImageNet, which contains many real-life photos of food, just like our food dataset.
If feature extraction already works well, fine-tuning may not be necessary for further improvement. Datasets whose images differ more from the pre-trained model's data typically benefit more from fine-tuning.
Saving our trained model¶
To avoid having to retrain the model again, let's save it to disk with the save() method.
# Save model to computer for later use if needed
model.save('101_food_class_10_percent_saved_big_dog_model')
INFO:tensorflow:Assets written to: 101_food_class_10_percent_saved_big_dog_model\assets
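A quick way to confirm a save worked is to load the model back and compare predictions. The sketch below uses a tiny stand-in model (and a hypothetical `.keras` filename) so it runs on its own; the real model above, saved in SavedModel directory format, round-trips the same way.

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model so this cell runs on its own
tiny_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])

tiny_model.save('tiny_demo_model.keras')                      # save to disk
loaded = tf.keras.models.load_model('tiny_demo_model.keras')  # load it back

# Predictions before and after the round trip should match exactly
x = np.random.rand(2, 4).astype('float32')
print(np.allclose(tiny_model.predict(x, verbose=0), loaded.predict(x, verbose=0)))  # True
```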
Evaluating the performance of the big dog model across all different classes¶
We've got a trained and saved model which has done fairly well in evaluation on the test dataset.
Let's dig deeper into the model's performance and get some visualizations going.
To do that, we can load the saved model and use it to make some predictions on the test dataset.
Note: Evaluating an ML model is as important as training one. Final metrics can be deceiving. Always visualize the model's performance on unseen data to make sure you aren't being fooled by good-looking numbers.
import tensorflow as tf
# download pre-trained model from google storage
!curl -O https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
saved_model_path = '06_101_food_class_10_percent_saved_big_dog_model.zip'
unzip_data(saved_model_path)
model = tf.keras.models.load_model(saved_model_path.split('.')[0])
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0
0 44.5M 0 174k 0 0 152k 0 0:04:58 0:00:01 0:04:57 153k
1 44.5M 1 812k 0 0 542k 0 0:01:24 0:00:01 0:01:23 543k
24 44.5M 24 10.7M 0 0 4396k 0 0:00:10 0:00:02 0:00:08 4401k
49 44.5M 49 22.2M 0 0 6503k 0 0:00:07 0:00:03 0:00:04 6508k
74 44.5M 74 33.0M 0 0 7536k 0 0:00:06 0:00:04 0:00:02 7540k
97 44.5M 97 43.6M 0 0 8140k 0 0:00:05 0:00:05 --:--:-- 9.9M
100 44.5M 100 44.5M 0 0 8195k 0 0:00:05 0:00:05 --:--:-- 10.7M
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_419470) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_446460) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_452094) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference__wrapped_model_408990) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_453618) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_454984) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_450696) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_418858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_450044) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_418250) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_453991) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_453285) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_416971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_455683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_415851) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_453166) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_420413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_450123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_417075) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_452840) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_417307) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_455063) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_419802) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_419858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block5a_se_reduce_layer_call_and_return_conditional_losses_452959) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block3a_expand_activation_layer_call_and_return_conditional_losses_451069) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_450370) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6b_expand_activation_layer_call_and_return_conditional_losses_419138) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6b_activation_layer_call_and_return_conditional_losses_419194) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6c_se_reduce_layer_call_and_return_conditional_losses_419573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block7a_expand_activation_layer_call_and_return_conditional_losses_420134) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_block4a_activation_layer_call_and_return_conditional_losses_417028) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_454611) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block3b_expand_activation_layer_call_and_return_conditional_losses_416639) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block4c_se_reduce_layer_call_and_return_conditional_losses_417686) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_417251) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6d_se_reduce_layer_call_and_return_conditional_losses_455103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block2b_se_reduce_layer_call_and_return_conditional_losses_450815) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block2b_se_reduce_layer_call_and_return_conditional_losses_416130) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6b_activation_layer_call_and_return_conditional_losses_454317) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6a_se_reduce_layer_call_and_return_conditional_losses_418962) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_block6b_se_reduce_layer_call_and_return_conditional_losses_419241) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
To make sure our loaded model is indeed a trained model, let's evaluate its performance on the test dataset.
# check to see if loaded model is a trained model
loaded_loss, loaded_accuracy = model.evaluate(test_data)
loaded_loss, loaded_accuracy
790/790 [==============================] - 711s 898ms/step - loss: 1.8027 - accuracy: 0.6078
(1.8027207851409912, 0.6077623963356018)
Looks like the loaded model is performing just as well as the original model did when evaluated on the full test data.
Making predictions with our trained model¶
To evaluate the trained model, we need to make some predictions with it and compare those predictions to the test dataset labels.
As the model has never seen the test dataset, this should give us an indication of how the model will perform in the real world on data similar to what it was trained on.
To make predictions with the model, we call the predict() method and pass it the test data.
# make predictions with model
pred_probs = model.predict(test_data, verbose=1) # set verbosity to see how long it will take
790/790 [==============================] - 661s 834ms/step
We just passed all test images to the model, and asked for it to make a prediction per image of food.
So how many predictions are made?
# how many predictions are there?
len(pred_probs)
25250
With each image getting a prediction probability for each of the 101 classes, that's 25,250 images with 101 possibilities each.
# what's the shape of our predictions?
pred_probs.shape
(25250, 101)
What we have is called a predictions probability tensor (or array).
Let's see what the first 10 looks like
pred_probs[:10]
array([[5.9541803e-02, 3.5742237e-06, 4.1377347e-02, ..., 1.4138752e-09,
8.3531268e-05, 3.0897614e-03],
[9.6401668e-01, 1.3753035e-09, 8.4780005e-04, ..., 5.4287146e-05,
7.8361458e-12, 9.8464892e-10],
[9.5925862e-01, 3.2533895e-05, 1.4867117e-03, ..., 7.1892083e-07,
5.4396531e-07, 4.0275998e-05],
...,
[4.7313362e-01, 1.2931211e-07, 1.4805610e-03, ..., 5.9750286e-04,
6.6968983e-05, 2.3469302e-05],
[4.4571716e-02, 4.7265306e-07, 1.2258544e-01, ..., 6.3498819e-06,
7.5318171e-06, 3.6778876e-03],
[7.2438931e-01, 1.9249797e-09, 5.2311167e-05, ..., 1.2291438e-03,
1.5792799e-09, 9.6395903e-05]], dtype=float32)
It looks to be a bunch of tensors of really small numbers. How about we zoom in on one of them?
# we get one prediction probability per class, for each image
print(f'Number of prediction probabilities for sample 0: {len(pred_probs[0])}')
print(f'What prediction probability sample 0 looks like: \n {pred_probs[0]}')
print(f'The class with the highest predicted probability by the model for sample 0: {pred_probs[0].argmax()}')
Number of prediction probabilities for sample 0: 101 What prediction probability sample 0 looks like: [5.95418029e-02 3.57422368e-06 4.13773470e-02 1.06606712e-09 8.16151680e-09 8.66406058e-09 8.09273104e-07 8.56533518e-07 1.98592134e-05 8.09785718e-07 3.17280868e-09 9.86751957e-07 2.85323913e-04 7.80500442e-10 7.42301228e-04 3.89165434e-05 6.47412026e-06 2.49774348e-06 3.78915465e-05 2.06784350e-07 1.55385478e-05 8.15079147e-07 2.62307526e-06 2.00107877e-07 8.38284564e-07 5.42161115e-06 3.73912280e-06 1.31505269e-08 2.77616014e-03 2.80517943e-05 6.85629054e-10 2.55749892e-05 1.66890954e-04 7.64081243e-10 4.04532795e-04 1.31507765e-08 1.79575227e-06 1.44483045e-06 2.30629761e-02 8.24671304e-07 8.53669519e-07 1.71386864e-06 7.05258026e-06 1.84024014e-08 2.85536885e-07 7.94840162e-06 2.06818777e-06 1.85252830e-07 3.36200756e-08 3.15226294e-04 1.04110168e-05 8.54497102e-07 8.47418189e-01 1.05554800e-05 4.40948554e-07 3.74044794e-05 3.53065443e-05 3.24891153e-05 6.73152244e-05 1.28526594e-08 2.62199956e-10 1.03182419e-05 8.57441555e-05 1.05699201e-06 2.12935470e-06 3.76377102e-05 7.59745546e-08 2.53405946e-04 9.29065152e-07 1.25982158e-04 6.26223436e-06 1.24587913e-08 4.05197461e-05 6.87283404e-08 1.25463464e-06 5.28879660e-08 7.54253193e-08 7.53988934e-05 7.75409208e-05 6.40267046e-07 9.90336275e-07 2.22261660e-05 1.50140704e-05 1.40385367e-07 1.22326192e-05 1.90447737e-02 5.00000533e-05 4.62264643e-06 1.53884358e-07 3.38243041e-07 3.92285360e-09 1.65638838e-07 8.13211809e-05 4.89655076e-06 2.40683391e-07 2.31242102e-05 3.10408417e-04 3.13802557e-05 1.41387524e-09 8.35312676e-05 3.08976136e-03] The class with the highest predicted probability by the mode for sample 0: 52
For every image tensor we pass to the model, the number of output neurons and the softmax activation of the last layer (layers.Dense(len(train_data_all_10_percent.class_names), activation='softmax')) produce a prediction probability between 0 and 1 for each of the 101 classes.
You can consider the index with the highest value in prediction probability to be what the model thinks is the most likely label.
Note: The nature of softmax is that the output probabilities are distributed across all classes and sum to 1.
We can find the index of the class with the highest value using the argmax() method.
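The two ideas above (softmax probabilities summing to 1, and argmax picking the most likely class) can be sketched with plain NumPy. This is a minimal illustration on made-up scores, not the model code:

```python
import numpy as np

def softmax(logits):
    # exponentiate (shifted by the max for numerical stability), then normalise
    exps = np.exp(logits - logits.max())
    return exps / exps.sum()

scores = np.array([2.0, 1.0, 0.1])  # made-up raw scores for 3 classes
probs = softmax(scores)
print(np.isclose(probs.sum(), 1.0))  # -> True (probabilities sum to 1)
print(probs.argmax())                # -> 0 (index of the most likely class)
```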
# get the class predictions of each label
pred_classes = pred_probs.argmax(axis=1)
# how do they look?
pred_classes[:10]
array([52, 0, 0, 80, 79, 61, 29, 0, 85, 0], dtype=int64)
Now we have the predicted class index for each of the samples in our test dataset :)
We'll be able to compare them to the test dataset labels and further evaluate our model.
To get the test dataset labels, we'll need to unravel the test_data object (which is a form of tf.data.Dataset) using the unbatch() method.
Doing so gives us access to the images and labels. Since the labels are one-hot encoded, the argmax() method is again useful for finding the index of each class.
Note: This is why shuffle=False is essential for our test dataset. If the test data were shuffled every time we used it, the location of, say, image[0] would be different on every pass, making it impossible to compare predictions to labels.
# Note: this might take a minute or so due to unravelling 790 batches
y_labels = []
for images, labels in test_data.unbatch(): # unbatch the test data and get images and labels
y_labels.append(labels.numpy().argmax()) # append the index which has largest value
y_labels[:10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
The final check is to see how many labels we have.
# how many labels are there?
len(y_labels)
25250
As expected, the number of labels matches the number of images. Time to compare the predictions to the labels.
Evaluating our model's predictions¶
A very simple evaluation is to use Scikit-learn's accuracy_score() function which compares truth labels to predicted labels and returns an accuracy score.
If the order of both arrays is correct, we should get roughly the same accuracy value as the .evaluate() method gave earlier.
# get accuracy score by comparing predicted classes to ground truth labels
from sklearn.metrics import accuracy_score
sklearn_accuracy = accuracy_score(y_labels, pred_classes)
sklearn_accuracy
0.6077623762376237
# does the .evaluate() value match closely to the value above?
import numpy as np
print(f'Close? {np.isclose(loaded_accuracy, sklearn_accuracy)} | Difference: {loaded_accuracy - sklearn_accuracy}')
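The same accuracy value can also be computed by hand, since accuracy is just the proportion of predictions that match the labels. A small demo with made-up arrays standing in for y_labels and pred_classes:

```python
import numpy as np

# made-up stand-ins for y_labels and pred_classes
y_labels_demo = np.array([0, 1, 2, 2, 1])
pred_classes_demo = np.array([0, 1, 1, 2, 1])

# accuracy = proportion of matching entries (4 out of 5 here)
manual_accuracy = (y_labels_demo == pred_classes_demo).mean()
print(manual_accuracy)  # -> 0.8
```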
It looks like the order of both datasets is correct.
How about we visualize this in a confusion matrix? We'll make use of the make_confusion_matrix function from our helper_functions script.
# import confusion matrix from helper function
from helper_functions import make_confusion_matrix
# note: the following confusion matrix code is a remix of scikit-learn's plot_confusion_matrix function
import itertools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# Get the class names to input for confusion matrix
class_names = test_data.class_names
# Plot a confusion matrix with all 25250 predictions, ground truth labels and 101 classes
make_confusion_matrix(y_true=y_labels,
y_pred=pred_classes,
classes=class_names,
figsize=(100, 100),
text_size=20,
norm=False,
savefig=True)
This is a very big confusion matrix. It may look daunting at first, but zooming in gives a lot of insight into which classes get 'confused' with each other most often, and what the model predicts instead.
The good news is that the majority of predictions line up diagonally from top left to bottom right, indicating the predicted class matches the true class.
It seems the model gets confused most often on visually similar foods, like filet_mignon and pork_chop, or chocolate_cake and tiramisu.
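Rather than hunting for confused pairs by eye, one way to find the most-confused pair programmatically (a sketch on toy labels, not our 101-class data) is to zero out the confusion matrix diagonal and take the largest remaining cell:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# toy labels: class 2 keeps getting predicted as class 0
y_true = [0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 0, 0, 0, 2]

cm = confusion_matrix(y_true, y_pred)  # rows = true class, cols = predicted class
np.fill_diagonal(cm, 0)  # zero out correct predictions on the diagonal
true_idx, pred_idx = np.unravel_index(cm.argmax(), cm.shape)
print(true_idx, pred_idx)  # -> 2 0 (class 2 most often mistaken for class 0)
```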
Since this is a classification problem, we can further evaluate the model's predictions using Scikit-Learn's classification_report() function.
from sklearn.metrics import classification_report
print(classification_report(y_labels, pred_classes))
precision recall f1-score support
0 0.29 0.20 0.24 250
1 0.51 0.69 0.59 250
2 0.56 0.65 0.60 250
3 0.74 0.53 0.62 250
4 0.73 0.43 0.54 250
5 0.34 0.54 0.42 250
6 0.67 0.79 0.72 250
7 0.82 0.76 0.79 250
8 0.40 0.37 0.39 250
9 0.62 0.44 0.51 250
10 0.62 0.42 0.50 250
11 0.84 0.49 0.62 250
12 0.52 0.74 0.61 250
13 0.56 0.60 0.58 250
14 0.56 0.59 0.57 250
15 0.44 0.32 0.37 250
16 0.45 0.75 0.57 250
17 0.37 0.51 0.43 250
18 0.43 0.60 0.50 250
19 0.68 0.60 0.64 250
20 0.68 0.75 0.71 250
21 0.35 0.64 0.45 250
22 0.30 0.37 0.33 250
23 0.66 0.77 0.71 250
24 0.83 0.72 0.77 250
25 0.76 0.71 0.73 250
26 0.51 0.42 0.46 250
27 0.78 0.72 0.75 250
28 0.70 0.69 0.69 250
29 0.70 0.68 0.69 250
30 0.92 0.63 0.75 250
31 0.78 0.70 0.74 250
32 0.75 0.83 0.79 250
33 0.89 0.98 0.94 250
34 0.68 0.78 0.72 250
35 0.78 0.66 0.72 250
36 0.53 0.56 0.55 250
37 0.30 0.55 0.39 250
38 0.78 0.63 0.69 250
39 0.27 0.33 0.30 250
40 0.72 0.81 0.76 250
41 0.81 0.62 0.70 250
42 0.50 0.58 0.54 250
43 0.75 0.60 0.67 250
44 0.74 0.45 0.56 250
45 0.77 0.85 0.81 250
46 0.81 0.46 0.58 250
47 0.44 0.49 0.46 250
48 0.45 0.81 0.58 250
49 0.50 0.44 0.47 250
50 0.54 0.39 0.46 250
51 0.71 0.86 0.78 250
52 0.51 0.77 0.61 250
53 0.67 0.68 0.68 250
54 0.88 0.75 0.81 250
55 0.86 0.69 0.76 250
56 0.56 0.24 0.34 250
57 0.62 0.45 0.52 250
58 0.68 0.58 0.62 250
59 0.70 0.37 0.49 250
60 0.83 0.59 0.69 250
61 0.54 0.81 0.65 250
62 0.72 0.49 0.58 250
63 0.94 0.86 0.90 250
64 0.78 0.85 0.81 250
65 0.82 0.82 0.82 250
66 0.69 0.32 0.44 250
67 0.41 0.58 0.48 250
68 0.90 0.78 0.83 250
69 0.84 0.82 0.83 250
70 0.62 0.83 0.71 250
71 0.81 0.46 0.59 250
72 0.64 0.65 0.65 250
73 0.51 0.44 0.47 250
74 0.72 0.61 0.66 250
75 0.84 0.90 0.87 250
76 0.78 0.78 0.78 250
77 0.36 0.27 0.31 250
78 0.79 0.74 0.76 250
79 0.44 0.81 0.57 250
80 0.57 0.60 0.59 250
81 0.65 0.70 0.68 250
82 0.38 0.31 0.34 250
83 0.58 0.80 0.67 250
84 0.61 0.38 0.47 250
85 0.44 0.74 0.55 250
86 0.71 0.86 0.78 250
87 0.41 0.39 0.40 250
88 0.83 0.80 0.81 250
89 0.71 0.31 0.43 250
90 0.92 0.69 0.79 250
91 0.83 0.87 0.85 250
92 0.68 0.65 0.67 250
93 0.31 0.38 0.34 250
94 0.61 0.54 0.57 250
95 0.74 0.61 0.67 250
96 0.56 0.29 0.38 250
97 0.45 0.74 0.56 250
98 0.47 0.33 0.39 250
99 0.52 0.27 0.35 250
100 0.59 0.70 0.64 250
accuracy 0.61 25250
macro avg 0.63 0.61 0.61 25250
weighted avg 0.63 0.61 0.61 25250
The classification_report() outputs precision, recall, and f1-scores per class.
A refresher:
- Precision: The proportion of true positives over the total number of samples predicted as positive. Higher precision means fewer false positives (real answer = 0, but predicted 1).
- Recall: The proportion of true positives over the total number of actual positives in the dataset. Higher recall means fewer false negatives (real answer = 1, but predicted 0).
- f1-score: Combines precision and recall into one metric (their harmonic mean). The higher the score, the better.
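The refresher above can be checked on a tiny made-up binary example, where the true/false positive/negative counts are easy to see by hand:

```python
from sklearn.metrics import precision_score, recall_score, f1_score

# toy binary example: 2 true positives, 1 false positive, 1 false negative
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]

precision = precision_score(y_true, y_pred)  # TP / (TP + FP) = 2/3
recall = recall_score(y_true, y_pred)        # TP / (TP + FN) = 2/3
f1 = f1_score(y_true, y_pred)                # harmonic mean of the two
print(precision, recall, f1)  # all three ≈ 0.667 here, since FP == FN
```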
The classification report output is helpful, but with so many classes it's hard to read.
Let's see if we can make it easier through visualization.
We'll get classification_report() as a dictionary using output_dict=True.
# get a dictionary for the classification report
classification_report_dict = classification_report(y_labels, pred_classes, output_dict=True)
classification_report_dict
{'0': {'precision': 0.29310344827586204,
'recall': 0.204,
'f1-score': 0.24056603773584906,
'support': 250.0},
'1': {'precision': 0.5088235294117647,
'recall': 0.692,
'f1-score': 0.5864406779661017,
'support': 250.0},
'2': {'precision': 0.5625,
'recall': 0.648,
'f1-score': 0.6022304832713755,
'support': 250.0},
'3': {'precision': 0.7415730337078652,
'recall': 0.528,
'f1-score': 0.616822429906542,
'support': 250.0},
'4': {'precision': 0.7346938775510204,
'recall': 0.432,
'f1-score': 0.5440806045340051,
'support': 250.0},
'5': {'precision': 0.34177215189873417,
'recall': 0.54,
'f1-score': 0.4186046511627907,
'support': 250.0},
'6': {'precision': 0.6677966101694915,
'recall': 0.788,
'f1-score': 0.7229357798165138,
'support': 250.0},
'7': {'precision': 0.8197424892703863,
'recall': 0.764,
'f1-score': 0.7908902691511387,
'support': 250.0},
'8': {'precision': 0.4025974025974026,
'recall': 0.372,
'f1-score': 0.3866943866943867,
'support': 250.0},
'9': {'precision': 0.6193181818181818,
'recall': 0.436,
'f1-score': 0.5117370892018779,
'support': 250.0},
'10': {'precision': 0.6235294117647059,
'recall': 0.424,
'f1-score': 0.5047619047619047,
'support': 250.0},
'11': {'precision': 0.8356164383561644,
'recall': 0.488,
'f1-score': 0.6161616161616161,
'support': 250.0},
'12': {'precision': 0.5196629213483146,
'recall': 0.74,
'f1-score': 0.6105610561056105,
'support': 250.0},
'13': {'precision': 0.5601503759398496,
'recall': 0.596,
'f1-score': 0.5775193798449613,
'support': 250.0},
'14': {'precision': 0.5584905660377358,
'recall': 0.592,
'f1-score': 0.574757281553398,
'support': 250.0},
'15': {'precision': 0.4388888888888889,
'recall': 0.316,
'f1-score': 0.3674418604651163,
'support': 250.0},
'16': {'precision': 0.4530120481927711,
'recall': 0.752,
'f1-score': 0.5654135338345865,
'support': 250.0},
'17': {'precision': 0.3659942363112392,
'recall': 0.508,
'f1-score': 0.42546063651591287,
'support': 250.0},
'18': {'precision': 0.4318840579710145,
'recall': 0.596,
'f1-score': 0.5008403361344538,
'support': 250.0},
'19': {'precision': 0.6832579185520362,
'recall': 0.604,
'f1-score': 0.6411889596602972,
'support': 250.0},
'20': {'precision': 0.68,
'recall': 0.748,
'f1-score': 0.7123809523809523,
'support': 250.0},
'21': {'precision': 0.350109409190372,
'recall': 0.64,
'f1-score': 0.4526166902404526,
'support': 250.0},
'22': {'precision': 0.2977346278317152,
'recall': 0.368,
'f1-score': 0.3291592128801431,
'support': 250.0},
'23': {'precision': 0.6632302405498282,
'recall': 0.772,
'f1-score': 0.7134935304990758,
'support': 250.0},
'24': {'precision': 0.8294930875576036,
'recall': 0.72,
'f1-score': 0.7708779443254818,
'support': 250.0},
'25': {'precision': 0.7574468085106383,
'recall': 0.712,
'f1-score': 0.734020618556701,
'support': 250.0},
'26': {'precision': 0.5147058823529411,
'recall': 0.42,
'f1-score': 0.46255506607929514,
'support': 250.0},
'27': {'precision': 0.776824034334764,
'recall': 0.724,
'f1-score': 0.7494824016563147,
'support': 250.0},
'28': {'precision': 0.6991869918699187,
'recall': 0.688,
'f1-score': 0.6935483870967742,
'support': 250.0},
'29': {'precision': 0.7024793388429752,
'recall': 0.68,
'f1-score': 0.6910569105691057,
'support': 250.0},
'30': {'precision': 0.9235294117647059,
'recall': 0.628,
'f1-score': 0.7476190476190476,
'support': 250.0},
'31': {'precision': 0.7802690582959642,
'recall': 0.696,
'f1-score': 0.7357293868921776,
'support': 250.0},
'32': {'precision': 0.7472924187725631,
'recall': 0.828,
'f1-score': 0.7855787476280834,
'support': 250.0},
'33': {'precision': 0.8945454545454545,
'recall': 0.984,
'f1-score': 0.9371428571428572,
'support': 250.0},
'34': {'precision': 0.6783216783216783,
'recall': 0.776,
'f1-score': 0.7238805970149254,
'support': 250.0},
'35': {'precision': 0.7819905213270142,
'recall': 0.66,
'f1-score': 0.7158351409978309,
'support': 250.0},
'36': {'precision': 0.5320754716981132,
'recall': 0.564,
'f1-score': 0.5475728155339806,
'support': 250.0},
'37': {'precision': 0.29912663755458513,
'recall': 0.548,
'f1-score': 0.3870056497175141,
'support': 250.0},
'38': {'precision': 0.7772277227722773,
'recall': 0.628,
'f1-score': 0.6946902654867256,
'support': 250.0},
'39': {'precision': 0.2694805194805195,
'recall': 0.332,
'f1-score': 0.2974910394265233,
'support': 250.0},
'40': {'precision': 0.7214285714285714,
'recall': 0.808,
'f1-score': 0.7622641509433963,
'support': 250.0},
'41': {'precision': 0.8115183246073299,
'recall': 0.62,
'f1-score': 0.7029478458049887,
'support': 250.0},
'42': {'precision': 0.5,
'recall': 0.58,
'f1-score': 0.5370370370370371,
'support': 250.0},
'43': {'precision': 0.746268656716418,
'recall': 0.6,
'f1-score': 0.6651884700665188,
'support': 250.0},
'44': {'precision': 0.7417218543046358,
'recall': 0.448,
'f1-score': 0.5586034912718204,
'support': 250.0},
'45': {'precision': 0.7745454545454545,
'recall': 0.852,
'f1-score': 0.8114285714285714,
'support': 250.0},
'46': {'precision': 0.8085106382978723,
'recall': 0.456,
'f1-score': 0.5831202046035806,
'support': 250.0},
'47': {'precision': 0.4392857142857143,
'recall': 0.492,
'f1-score': 0.4641509433962264,
'support': 250.0},
'48': {'precision': 0.4481236203090508,
'recall': 0.812,
'f1-score': 0.577524893314367,
'support': 250.0},
'49': {'precision': 0.5045454545454545,
'recall': 0.444,
'f1-score': 0.4723404255319149,
'support': 250.0},
'50': {'precision': 0.5444444444444444,
'recall': 0.392,
'f1-score': 0.4558139534883721,
'support': 250.0},
'51': {'precision': 0.7081967213114754,
'recall': 0.864,
'f1-score': 0.7783783783783784,
'support': 250.0},
'52': {'precision': 0.5092838196286472,
'recall': 0.768,
'f1-score': 0.6124401913875598,
'support': 250.0},
'53': {'precision': 0.6719367588932806,
'recall': 0.68,
'f1-score': 0.6759443339960238,
'support': 250.0},
'54': {'precision': 0.8785046728971962,
'recall': 0.752,
'f1-score': 0.8103448275862069,
'support': 250.0},
'55': {'precision': 0.86,
'recall': 0.688,
'f1-score': 0.7644444444444445,
'support': 250.0},
'56': {'precision': 0.5596330275229358,
'recall': 0.244,
'f1-score': 0.3398328690807799,
'support': 250.0},
'57': {'precision': 0.6222222222222222,
'recall': 0.448,
'f1-score': 0.5209302325581395,
'support': 250.0},
'58': {'precision': 0.6792452830188679,
'recall': 0.576,
'f1-score': 0.6233766233766234,
'support': 250.0},
'59': {'precision': 0.7045454545454546,
'recall': 0.372,
'f1-score': 0.4869109947643979,
'support': 250.0},
'60': {'precision': 0.8305084745762712,
'recall': 0.588,
'f1-score': 0.6885245901639344,
'support': 250.0},
'61': {'precision': 0.543010752688172,
'recall': 0.808,
'f1-score': 0.6495176848874598,
'support': 250.0},
'62': {'precision': 0.7218934911242604,
'recall': 0.488,
'f1-score': 0.5823389021479713,
'support': 250.0},
'63': {'precision': 0.9385964912280702,
'recall': 0.856,
'f1-score': 0.895397489539749,
'support': 250.0},
'64': {'precision': 0.7773722627737226,
'recall': 0.852,
'f1-score': 0.8129770992366412,
'support': 250.0},
'65': {'precision': 0.82, 'recall': 0.82, 'f1-score': 0.82, 'support': 250.0},
'66': {'precision': 0.6923076923076923,
'recall': 0.324,
'f1-score': 0.44141689373297005,
'support': 250.0},
'67': {'precision': 0.4090909090909091,
'recall': 0.576,
'f1-score': 0.47840531561461797,
'support': 250.0},
'68': {'precision': 0.8981481481481481,
'recall': 0.776,
'f1-score': 0.8326180257510729,
'support': 250.0},
'69': {'precision': 0.8442622950819673,
'recall': 0.824,
'f1-score': 0.8340080971659919,
'support': 250.0},
'70': {'precision': 0.6216216216216216,
'recall': 0.828,
'f1-score': 0.7101200686106347,
'support': 250.0},
'71': {'precision': 0.8111888111888111,
'recall': 0.464,
'f1-score': 0.5903307888040712,
'support': 250.0},
'72': {'precision': 0.6417322834645669,
'recall': 0.652,
'f1-score': 0.6468253968253969,
'support': 250.0},
'73': {'precision': 0.5091743119266054,
'recall': 0.444,
'f1-score': 0.47435897435897434,
'support': 250.0},
'74': {'precision': 0.7169811320754716,
'recall': 0.608,
'f1-score': 0.658008658008658,
'support': 250.0},
'75': {'precision': 0.8389513108614233,
'recall': 0.896,
'f1-score': 0.8665377176015474,
'support': 250.0},
'76': {'precision': 0.7777777777777778,
'recall': 0.784,
'f1-score': 0.7808764940239044,
'support': 250.0},
'77': {'precision': 0.3641304347826087,
'recall': 0.268,
'f1-score': 0.3087557603686636,
'support': 250.0},
'78': {'precision': 0.7863247863247863,
'recall': 0.736,
'f1-score': 0.7603305785123967,
'support': 250.0},
'79': {'precision': 0.44130434782608696,
'recall': 0.812,
'f1-score': 0.571830985915493,
'support': 250.0},
'80': {'precision': 0.5747126436781609,
'recall': 0.6,
'f1-score': 0.5870841487279843,
'support': 250.0},
'81': {'precision': 0.6529850746268657,
'recall': 0.7,
'f1-score': 0.6756756756756757,
'support': 250.0},
'82': {'precision': 0.3804878048780488,
'recall': 0.312,
'f1-score': 0.34285714285714286,
'support': 250.0},
'83': {'precision': 0.5780346820809249,
'recall': 0.8,
'f1-score': 0.6711409395973155,
'support': 250.0},
'84': {'precision': 0.6103896103896104,
'recall': 0.376,
'f1-score': 0.46534653465346537,
'support': 250.0},
'85': {'precision': 0.4423076923076923,
'recall': 0.736,
'f1-score': 0.5525525525525525,
'support': 250.0},
'86': {'precision': 0.7081967213114754,
'recall': 0.864,
'f1-score': 0.7783783783783784,
'support': 250.0},
'87': {'precision': 0.40756302521008403,
'recall': 0.388,
'f1-score': 0.3975409836065574,
'support': 250.0},
'88': {'precision': 0.8264462809917356,
'recall': 0.8,
'f1-score': 0.8130081300813008,
'support': 250.0},
'89': {'precision': 0.7129629629629629,
'recall': 0.308,
'f1-score': 0.4301675977653631,
'support': 250.0},
'90': {'precision': 0.9153439153439153,
'recall': 0.692,
'f1-score': 0.7881548974943052,
'support': 250.0},
'91': {'precision': 0.8282442748091603,
'recall': 0.868,
'f1-score': 0.84765625,
'support': 250.0},
'92': {'precision': 0.6835443037974683,
'recall': 0.648,
'f1-score': 0.6652977412731006,
'support': 250.0},
'93': {'precision': 0.3114754098360656,
'recall': 0.38,
'f1-score': 0.34234234234234234,
'support': 250.0},
'94': {'precision': 0.6118721461187214,
'recall': 0.536,
'f1-score': 0.5714285714285714,
'support': 250.0},
'95': {'precision': 0.7427184466019418,
'recall': 0.612,
'f1-score': 0.6710526315789473,
'support': 250.0},
'96': {'precision': 0.5625,
'recall': 0.288,
'f1-score': 0.38095238095238093,
'support': 250.0},
'97': {'precision': 0.4547677261613692,
'recall': 0.744,
'f1-score': 0.5644916540212443,
'support': 250.0},
'98': {'precision': 0.4685714285714286,
'recall': 0.328,
'f1-score': 0.38588235294117645,
'support': 250.0},
'99': {'precision': 0.5193798449612403,
'recall': 0.268,
'f1-score': 0.35356200527704484,
'support': 250.0},
'100': {'precision': 0.5912162162162162,
'recall': 0.7,
'f1-score': 0.6410256410256411,
'support': 250.0},
'accuracy': 0.6077623762376237,
'macro avg': {'precision': 0.6328666845830312,
'recall': 0.6077623762376237,
'f1-score': 0.6061252197245782,
'support': 25250.0},
'weighted avg': {'precision': 0.6328666845830311,
'recall': 0.6077623762376237,
'f1-score': 0.6061252197245781,
'support': 25250.0}}
There are still quite a few values. So we'll narrow it down to the f1-score, since it combines both precision and recall into a single metric.
To extract it, we'll create an empty dictionary named class_f1_scores, then loop through each item of classification_report_dict, adding each class name as a key with its f1-score as the value.
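As a quick sanity check before extracting them: the f1-score is the harmonic mean of precision and recall, so any value in the report above can be reproduced by hand. Here's a minimal sketch using the precision and recall reported for class '63' (macarons) in the output above:

```python
# F1 is the harmonic mean of precision and recall
# precision/recall values for class '63' taken from the report above
precision, recall = 0.9385964912280702, 0.856
f1 = 2 * precision * recall / (precision + recall)
print(f1)  # matches the reported f1-score of ~0.8954 for class '63'
```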
classification_report_dict.items()
dict_items([('0', {'precision': 0.29310344827586204, 'recall': 0.204, 'f1-score': 0.24056603773584906, 'support': 250.0}), ('1', {'precision': 0.5088235294117647, 'recall': 0.692, 'f1-score': 0.5864406779661017, 'support': 250.0}), ('2', {'precision': 0.5625, 'recall': 0.648, 'f1-score': 0.6022304832713755, 'support': 250.0}), ('3', {'precision': 0.7415730337078652, 'recall': 0.528, 'f1-score': 0.616822429906542, 'support': 250.0}), ('4', {'precision': 0.7346938775510204, 'recall': 0.432, 'f1-score': 0.5440806045340051, 'support': 250.0}), ('5', {'precision': 0.34177215189873417, 'recall': 0.54, 'f1-score': 0.4186046511627907, 'support': 250.0}), ('6', {'precision': 0.6677966101694915, 'recall': 0.788, 'f1-score': 0.7229357798165138, 'support': 250.0}), ('7', {'precision': 0.8197424892703863, 'recall': 0.764, 'f1-score': 0.7908902691511387, 'support': 250.0}), ('8', {'precision': 0.4025974025974026, 'recall': 0.372, 'f1-score': 0.3866943866943867, 'support': 250.0}), ('9', {'precision': 0.6193181818181818, 'recall': 0.436, 'f1-score': 0.5117370892018779, 'support': 250.0}), ('10', {'precision': 0.6235294117647059, 'recall': 0.424, 'f1-score': 0.5047619047619047, 'support': 250.0}), ('11', {'precision': 0.8356164383561644, 'recall': 0.488, 'f1-score': 0.6161616161616161, 'support': 250.0}), ('12', {'precision': 0.5196629213483146, 'recall': 0.74, 'f1-score': 0.6105610561056105, 'support': 250.0}), ('13', {'precision': 0.5601503759398496, 'recall': 0.596, 'f1-score': 0.5775193798449613, 'support': 250.0}), ('14', {'precision': 0.5584905660377358, 'recall': 0.592, 'f1-score': 0.574757281553398, 'support': 250.0}), ('15', {'precision': 0.4388888888888889, 'recall': 0.316, 'f1-score': 0.3674418604651163, 'support': 250.0}), ('16', {'precision': 0.4530120481927711, 'recall': 0.752, 'f1-score': 0.5654135338345865, 'support': 250.0}), ('17', {'precision': 0.3659942363112392, 'recall': 0.508, 'f1-score': 0.42546063651591287, 'support': 250.0}), ('18', {'precision': 
0.4318840579710145, 'recall': 0.596, 'f1-score': 0.5008403361344538, 'support': 250.0}), ('19', {'precision': 0.6832579185520362, 'recall': 0.604, 'f1-score': 0.6411889596602972, 'support': 250.0}), ('20', {'precision': 0.68, 'recall': 0.748, 'f1-score': 0.7123809523809523, 'support': 250.0}), ('21', {'precision': 0.350109409190372, 'recall': 0.64, 'f1-score': 0.4526166902404526, 'support': 250.0}), ('22', {'precision': 0.2977346278317152, 'recall': 0.368, 'f1-score': 0.3291592128801431, 'support': 250.0}), ('23', {'precision': 0.6632302405498282, 'recall': 0.772, 'f1-score': 0.7134935304990758, 'support': 250.0}), ('24', {'precision': 0.8294930875576036, 'recall': 0.72, 'f1-score': 0.7708779443254818, 'support': 250.0}), ('25', {'precision': 0.7574468085106383, 'recall': 0.712, 'f1-score': 0.734020618556701, 'support': 250.0}), ('26', {'precision': 0.5147058823529411, 'recall': 0.42, 'f1-score': 0.46255506607929514, 'support': 250.0}), ('27', {'precision': 0.776824034334764, 'recall': 0.724, 'f1-score': 0.7494824016563147, 'support': 250.0}), ('28', {'precision': 0.6991869918699187, 'recall': 0.688, 'f1-score': 0.6935483870967742, 'support': 250.0}), ('29', {'precision': 0.7024793388429752, 'recall': 0.68, 'f1-score': 0.6910569105691057, 'support': 250.0}), ('30', {'precision': 0.9235294117647059, 'recall': 0.628, 'f1-score': 0.7476190476190476, 'support': 250.0}), ('31', {'precision': 0.7802690582959642, 'recall': 0.696, 'f1-score': 0.7357293868921776, 'support': 250.0}), ('32', {'precision': 0.7472924187725631, 'recall': 0.828, 'f1-score': 0.7855787476280834, 'support': 250.0}), ('33', {'precision': 0.8945454545454545, 'recall': 0.984, 'f1-score': 0.9371428571428572, 'support': 250.0}), ('34', {'precision': 0.6783216783216783, 'recall': 0.776, 'f1-score': 0.7238805970149254, 'support': 250.0}), ('35', {'precision': 0.7819905213270142, 'recall': 0.66, 'f1-score': 0.7158351409978309, 'support': 250.0}), ('36', {'precision': 0.5320754716981132, 'recall': 0.564, 
'f1-score': 0.5475728155339806, 'support': 250.0}), ('37', {'precision': 0.29912663755458513, 'recall': 0.548, 'f1-score': 0.3870056497175141, 'support': 250.0}), ('38', {'precision': 0.7772277227722773, 'recall': 0.628, 'f1-score': 0.6946902654867256, 'support': 250.0}), ('39', {'precision': 0.2694805194805195, 'recall': 0.332, 'f1-score': 0.2974910394265233, 'support': 250.0}), ('40', {'precision': 0.7214285714285714, 'recall': 0.808, 'f1-score': 0.7622641509433963, 'support': 250.0}), ('41', {'precision': 0.8115183246073299, 'recall': 0.62, 'f1-score': 0.7029478458049887, 'support': 250.0}), ('42', {'precision': 0.5, 'recall': 0.58, 'f1-score': 0.5370370370370371, 'support': 250.0}), ('43', {'precision': 0.746268656716418, 'recall': 0.6, 'f1-score': 0.6651884700665188, 'support': 250.0}), ('44', {'precision': 0.7417218543046358, 'recall': 0.448, 'f1-score': 0.5586034912718204, 'support': 250.0}), ('45', {'precision': 0.7745454545454545, 'recall': 0.852, 'f1-score': 0.8114285714285714, 'support': 250.0}), ('46', {'precision': 0.8085106382978723, 'recall': 0.456, 'f1-score': 0.5831202046035806, 'support': 250.0}), ('47', {'precision': 0.4392857142857143, 'recall': 0.492, 'f1-score': 0.4641509433962264, 'support': 250.0}), ('48', {'precision': 0.4481236203090508, 'recall': 0.812, 'f1-score': 0.577524893314367, 'support': 250.0}), ('49', {'precision': 0.5045454545454545, 'recall': 0.444, 'f1-score': 0.4723404255319149, 'support': 250.0}), ('50', {'precision': 0.5444444444444444, 'recall': 0.392, 'f1-score': 0.4558139534883721, 'support': 250.0}), ('51', {'precision': 0.7081967213114754, 'recall': 0.864, 'f1-score': 0.7783783783783784, 'support': 250.0}), ('52', {'precision': 0.5092838196286472, 'recall': 0.768, 'f1-score': 0.6124401913875598, 'support': 250.0}), ('53', {'precision': 0.6719367588932806, 'recall': 0.68, 'f1-score': 0.6759443339960238, 'support': 250.0}), ('54', {'precision': 0.8785046728971962, 'recall': 0.752, 'f1-score': 0.8103448275862069, 
'support': 250.0}), ('55', {'precision': 0.86, 'recall': 0.688, 'f1-score': 0.7644444444444445, 'support': 250.0}), ('56', {'precision': 0.5596330275229358, 'recall': 0.244, 'f1-score': 0.3398328690807799, 'support': 250.0}), ('57', {'precision': 0.6222222222222222, 'recall': 0.448, 'f1-score': 0.5209302325581395, 'support': 250.0}), ('58', {'precision': 0.6792452830188679, 'recall': 0.576, 'f1-score': 0.6233766233766234, 'support': 250.0}), ('59', {'precision': 0.7045454545454546, 'recall': 0.372, 'f1-score': 0.4869109947643979, 'support': 250.0}), ('60', {'precision': 0.8305084745762712, 'recall': 0.588, 'f1-score': 0.6885245901639344, 'support': 250.0}), ('61', {'precision': 0.543010752688172, 'recall': 0.808, 'f1-score': 0.6495176848874598, 'support': 250.0}), ('62', {'precision': 0.7218934911242604, 'recall': 0.488, 'f1-score': 0.5823389021479713, 'support': 250.0}), ('63', {'precision': 0.9385964912280702, 'recall': 0.856, 'f1-score': 0.895397489539749, 'support': 250.0}), ('64', {'precision': 0.7773722627737226, 'recall': 0.852, 'f1-score': 0.8129770992366412, 'support': 250.0}), ('65', {'precision': 0.82, 'recall': 0.82, 'f1-score': 0.82, 'support': 250.0}), ('66', {'precision': 0.6923076923076923, 'recall': 0.324, 'f1-score': 0.44141689373297005, 'support': 250.0}), ('67', {'precision': 0.4090909090909091, 'recall': 0.576, 'f1-score': 0.47840531561461797, 'support': 250.0}), ('68', {'precision': 0.8981481481481481, 'recall': 0.776, 'f1-score': 0.8326180257510729, 'support': 250.0}), ('69', {'precision': 0.8442622950819673, 'recall': 0.824, 'f1-score': 0.8340080971659919, 'support': 250.0}), ('70', {'precision': 0.6216216216216216, 'recall': 0.828, 'f1-score': 0.7101200686106347, 'support': 250.0}), ('71', {'precision': 0.8111888111888111, 'recall': 0.464, 'f1-score': 0.5903307888040712, 'support': 250.0}), ('72', {'precision': 0.6417322834645669, 'recall': 0.652, 'f1-score': 0.6468253968253969, 'support': 250.0}), ('73', {'precision': 0.5091743119266054, 
'recall': 0.444, 'f1-score': 0.47435897435897434, 'support': 250.0}), ('74', {'precision': 0.7169811320754716, 'recall': 0.608, 'f1-score': 0.658008658008658, 'support': 250.0}), ('75', {'precision': 0.8389513108614233, 'recall': 0.896, 'f1-score': 0.8665377176015474, 'support': 250.0}), ('76', {'precision': 0.7777777777777778, 'recall': 0.784, 'f1-score': 0.7808764940239044, 'support': 250.0}), ('77', {'precision': 0.3641304347826087, 'recall': 0.268, 'f1-score': 0.3087557603686636, 'support': 250.0}), ('78', {'precision': 0.7863247863247863, 'recall': 0.736, 'f1-score': 0.7603305785123967, 'support': 250.0}), ('79', {'precision': 0.44130434782608696, 'recall': 0.812, 'f1-score': 0.571830985915493, 'support': 250.0}), ('80', {'precision': 0.5747126436781609, 'recall': 0.6, 'f1-score': 0.5870841487279843, 'support': 250.0}), ('81', {'precision': 0.6529850746268657, 'recall': 0.7, 'f1-score': 0.6756756756756757, 'support': 250.0}), ('82', {'precision': 0.3804878048780488, 'recall': 0.312, 'f1-score': 0.34285714285714286, 'support': 250.0}), ('83', {'precision': 0.5780346820809249, 'recall': 0.8, 'f1-score': 0.6711409395973155, 'support': 250.0}), ('84', {'precision': 0.6103896103896104, 'recall': 0.376, 'f1-score': 0.46534653465346537, 'support': 250.0}), ('85', {'precision': 0.4423076923076923, 'recall': 0.736, 'f1-score': 0.5525525525525525, 'support': 250.0}), ('86', {'precision': 0.7081967213114754, 'recall': 0.864, 'f1-score': 0.7783783783783784, 'support': 250.0}), ('87', {'precision': 0.40756302521008403, 'recall': 0.388, 'f1-score': 0.3975409836065574, 'support': 250.0}), ('88', {'precision': 0.8264462809917356, 'recall': 0.8, 'f1-score': 0.8130081300813008, 'support': 250.0}), ('89', {'precision': 0.7129629629629629, 'recall': 0.308, 'f1-score': 0.4301675977653631, 'support': 250.0}), ('90', {'precision': 0.9153439153439153, 'recall': 0.692, 'f1-score': 0.7881548974943052, 'support': 250.0}), ('91', {'precision': 0.8282442748091603, 'recall': 0.868, 
'f1-score': 0.84765625, 'support': 250.0}), ('92', {'precision': 0.6835443037974683, 'recall': 0.648, 'f1-score': 0.6652977412731006, 'support': 250.0}), ('93', {'precision': 0.3114754098360656, 'recall': 0.38, 'f1-score': 0.34234234234234234, 'support': 250.0}), ('94', {'precision': 0.6118721461187214, 'recall': 0.536, 'f1-score': 0.5714285714285714, 'support': 250.0}), ('95', {'precision': 0.7427184466019418, 'recall': 0.612, 'f1-score': 0.6710526315789473, 'support': 250.0}), ('96', {'precision': 0.5625, 'recall': 0.288, 'f1-score': 0.38095238095238093, 'support': 250.0}), ('97', {'precision': 0.4547677261613692, 'recall': 0.744, 'f1-score': 0.5644916540212443, 'support': 250.0}), ('98', {'precision': 0.4685714285714286, 'recall': 0.328, 'f1-score': 0.38588235294117645, 'support': 250.0}), ('99', {'precision': 0.5193798449612403, 'recall': 0.268, 'f1-score': 0.35356200527704484, 'support': 250.0}), ('100', {'precision': 0.5912162162162162, 'recall': 0.7, 'f1-score': 0.6410256410256411, 'support': 250.0}), ('accuracy', 0.6077623762376237), ('macro avg', {'precision': 0.6328666845830312, 'recall': 0.6077623762376237, 'f1-score': 0.6061252197245782, 'support': 25250.0}), ('weighted avg', {'precision': 0.6328666845830311, 'recall': 0.6077623762376237, 'f1-score': 0.6061252197245781, 'support': 25250.0})])
# create empty dictionary
class_f1_scores = {}
# loop through classification report items
for k, v in classification_report_dict.items():
    if k == 'accuracy': # stop once we reach the 'accuracy' key - the entries after it ('macro avg', 'weighted avg') aren't classes
        break
    else:
        # add each class name (looked up via the integer key) and its f1-score to the new dictionary
        class_f1_scores[class_names[int(k)]] = v['f1-score']
class_f1_scores
{'apple_pie': 0.24056603773584906,
'baby_back_ribs': 0.5864406779661017,
'baklava': 0.6022304832713755,
'beef_carpaccio': 0.616822429906542,
'beef_tartare': 0.5440806045340051,
'beet_salad': 0.4186046511627907,
'beignets': 0.7229357798165138,
'bibimbap': 0.7908902691511387,
'bread_pudding': 0.3866943866943867,
'breakfast_burrito': 0.5117370892018779,
'bruschetta': 0.5047619047619047,
'caesar_salad': 0.6161616161616161,
'cannoli': 0.6105610561056105,
'caprese_salad': 0.5775193798449613,
'carrot_cake': 0.574757281553398,
'ceviche': 0.3674418604651163,
'cheese_plate': 0.5654135338345865,
'cheesecake': 0.42546063651591287,
'chicken_curry': 0.5008403361344538,
'chicken_quesadilla': 0.6411889596602972,
'chicken_wings': 0.7123809523809523,
'chocolate_cake': 0.4526166902404526,
'chocolate_mousse': 0.3291592128801431,
'churros': 0.7134935304990758,
'clam_chowder': 0.7708779443254818,
'club_sandwich': 0.734020618556701,
'crab_cakes': 0.46255506607929514,
'creme_brulee': 0.7494824016563147,
'croque_madame': 0.6935483870967742,
'cup_cakes': 0.6910569105691057,
'deviled_eggs': 0.7476190476190476,
'donuts': 0.7357293868921776,
'dumplings': 0.7855787476280834,
'edamame': 0.9371428571428572,
'eggs_benedict': 0.7238805970149254,
'escargots': 0.7158351409978309,
'falafel': 0.5475728155339806,
'filet_mignon': 0.3870056497175141,
'fish_and_chips': 0.6946902654867256,
'foie_gras': 0.2974910394265233,
'french_fries': 0.7622641509433963,
'french_onion_soup': 0.7029478458049887,
'french_toast': 0.5370370370370371,
'fried_calamari': 0.6651884700665188,
'fried_rice': 0.5586034912718204,
'frozen_yogurt': 0.8114285714285714,
'garlic_bread': 0.5831202046035806,
'gnocchi': 0.4641509433962264,
'greek_salad': 0.577524893314367,
'grilled_cheese_sandwich': 0.4723404255319149,
'grilled_salmon': 0.4558139534883721,
'guacamole': 0.7783783783783784,
'gyoza': 0.6124401913875598,
'hamburger': 0.6759443339960238,
'hot_and_sour_soup': 0.8103448275862069,
'hot_dog': 0.7644444444444445,
'huevos_rancheros': 0.3398328690807799,
'hummus': 0.5209302325581395,
'ice_cream': 0.6233766233766234,
'lasagna': 0.4869109947643979,
'lobster_bisque': 0.6885245901639344,
'lobster_roll_sandwich': 0.6495176848874598,
'macaroni_and_cheese': 0.5823389021479713,
'macarons': 0.895397489539749,
'miso_soup': 0.8129770992366412,
'mussels': 0.82,
'nachos': 0.44141689373297005,
'omelette': 0.47840531561461797,
'onion_rings': 0.8326180257510729,
'oysters': 0.8340080971659919,
'pad_thai': 0.7101200686106347,
'paella': 0.5903307888040712,
'pancakes': 0.6468253968253969,
'panna_cotta': 0.47435897435897434,
'peking_duck': 0.658008658008658,
'pho': 0.8665377176015474,
'pizza': 0.7808764940239044,
'pork_chop': 0.3087557603686636,
'poutine': 0.7603305785123967,
'prime_rib': 0.571830985915493,
'pulled_pork_sandwich': 0.5870841487279843,
'ramen': 0.6756756756756757,
'ravioli': 0.34285714285714286,
'red_velvet_cake': 0.6711409395973155,
'risotto': 0.46534653465346537,
'samosa': 0.5525525525525525,
'sashimi': 0.7783783783783784,
'scallops': 0.3975409836065574,
'seaweed_salad': 0.8130081300813008,
'shrimp_and_grits': 0.4301675977653631,
'spaghetti_bolognese': 0.7881548974943052,
'spaghetti_carbonara': 0.84765625,
'spring_rolls': 0.6652977412731006,
'steak': 0.34234234234234234,
'strawberry_shortcake': 0.5714285714285714,
'sushi': 0.6710526315789473,
'tacos': 0.38095238095238093,
'takoyaki': 0.5644916540212443,
'tiramisu': 0.38588235294117645,
'tuna_tartare': 0.35356200527704484,
'waffles': 0.6410256410256411}
Looks good!
The dictionary is ordered alphabetically (since class_names is sorted alphabetically). However, we can order it differently.
We can turn the class_f1_scores dictionary into a pandas DataFrame and sort it by f1-score in descending order.
!x:\miniconda3\envs\tfenv\python -m pip install pandas
Requirement already satisfied: pandas in x:\miniconda3\envs\tfenv\lib\site-packages (2.3.3) Requirement already satisfied: numpy>=1.22.4 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (1.24.3) Requirement already satisfied: python-dateutil>=2.8.2 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (2025.2) Requirement already satisfied: tzdata>=2022.7 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (2025.2) Requirement already satisfied: six>=1.5 in x:\miniconda3\envs\tfenv\lib\site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)
# turn f1-scores into dataframe for visualization
import pandas as pd
f1_scores = pd.DataFrame({'class_names': list(class_f1_scores.keys()),
                          'f1-score': list(class_f1_scores.values())}).sort_values('f1-score', ascending=False)
f1_scores.head()
| class_names | f1-score | |
|---|---|---|
| 33 | edamame | 0.937143 |
| 63 | macarons | 0.895397 |
| 75 | pho | 0.866538 |
| 91 | spaghetti_carbonara | 0.847656 |
| 69 | oysters | 0.834008 |
Let's finish it off with a horizontal bar chart.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 25))
scores = ax.barh(range(len(f1_scores)), f1_scores["f1-score"].values)
ax.set_yticks(range(len(f1_scores)))
ax.set_yticklabels(list(f1_scores["class_names"]))
ax.set_xlabel("f1-score")
ax.set_title("F1-Scores for 101 Different Food Classes")
ax.invert_yaxis(); # reverse the order
def autolabel(rects): # modified version of: https://matplotlib.org/examples/api/barchart_demo.html
    """
    Attach a text label to the right of each bar displaying its width (its f1-score).
    """
    for rect in rects:
        width = rect.get_width()
        ax.text(1.03*width, rect.get_y() + rect.get_height()/1.5,
                f"{width:.2f}",
                ha='center', va='bottom')
autolabel(scores)
Visualizing performance makes a world of difference. Previously we only had list upon list of numbers with variable names; now we get a clear indication of how well the model predicts each class.
Findings like the individual performance of classes help us figure out possible next steps for our experiments. Perhaps we could collect more training data for the worst performing classes such as apple_pie or foie_gras, or maybe those dishes are just visually difficult to distinguish from other classes.
Exercise: Visualize the 3 worst performing classes and see if there are any trends or clues to them.
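As a starting point for the exercise, the worst performers can be pulled straight out of a sorted DataFrame. The sketch below uses a small hypothetical stand-in dictionary so it's self-contained; in the notebook you'd apply the same sort to the real class_f1_scores:

```python
import pandas as pd

# toy stand-in for the full class_f1_scores dictionary (hypothetical values)
class_f1_scores = {"apple_pie": 0.24, "foie_gras": 0.30,
                   "pho": 0.87, "edamame": 0.94}

f1_scores = pd.DataFrame({"class_names": list(class_f1_scores.keys()),
                          "f1-score": list(class_f1_scores.values())})

# sort ascending so the worst performers come first, then take the bottom 3
worst_3 = f1_scores.sort_values("f1-score").head(3)["class_names"].tolist()
print(worst_3)
```

From there, plotting a handful of random test images from each class in worst_3 often reveals whether the dishes are genuinely visually ambiguous.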
Visualizing predictions on test images¶
We can look at numbers and graphs all we want, but we won't really know how the model performs until we see it make predictions on actual images.
The model can't predict on any image we throw at it - the image must first be loaded into a tensor.
To begin, we'll create a function that loads an image into a tensor:
- Read in a target filepath using tf.io.read_file().
- Turn the image into a Tensor using tf.io.decode_image().
- Resize the image to the same size as the images our model was trained on (224x224) using tf.image.resize().
- Scale the pixel values to between 0 & 1 if necessary.
def load_and_prep_image(filename, img_shape=224, scale=True):
    '''
    Reads an image from filename, turns it into a tensor and reshapes it to
    (img_shape, img_shape, 3).

    Parameters:
    filename (str): filename of target image
    img_shape (int): size to resize target image to (default 224)
    scale (bool): whether to scale pixel values to the range 0-1 (default True)
    '''
    # read in the image
    img = tf.io.read_file(filename)
    # decode it into a tensor
    img = tf.io.decode_image(img)
    # resize the image
    img = tf.image.resize(img, [img_shape, img_shape])
    if scale:
        # rescale the image so values are between 0 and 1
        return img/255.
    else:
        return img
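As a quick self-contained sanity check (using a synthetic JPEG written to disk rather than a real food image, and repeating the same preprocessing steps inline so nothing else is needed), the pipeline should produce a (224, 224, 3) float tensor with values between 0 and 1:

```python
import tensorflow as tf

# write a small synthetic JPEG so the check doesn't need a real image
fake_img = tf.cast(tf.random.uniform((64, 64, 3), maxval=256, dtype=tf.int32), tf.uint8)
tf.io.write_file("fake_food.jpg", tf.io.encode_jpeg(fake_img))

# same steps as load_and_prep_image with scale=True
img = tf.io.read_file("fake_food.jpg")
img = tf.io.decode_image(img)
img = tf.image.resize(img, [224, 224]) / 255.
print(img.shape)  # (224, 224, 3)
```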
Our preprocessing function is complete.
Let's write some code to:
- Load a few random images from the test dataset.
- Make predictions on them.
- Plot the original image(s) along with the model's predicted label, the prediction probability of that label, and the ground truth label.
# make predictions on a series of random images
import os
import random
plt.figure(figsize=(17,10))
for i in range(3):
    # choose a random image from a random class
    class_name = random.choice(class_names) # pick a random class name
    filename = random.choice(os.listdir(test_dir + '/' + class_name)) # pick a random filename from that class's directory
    filepath = test_dir + class_name + '/' + filename # build the filepath from the class name and file name
    # load the image and make predictions
    img = load_and_prep_image(filepath, scale=False) # no scaling needed - EfficientNet has rescaling built in
    pred_prob = model.predict(tf.expand_dims(img, axis=0)) # the model expects shape [None, 224, 224, 3], where None is the batch dimension
    pred_class = class_names[pred_prob.argmax()] # get the class name at the max probability's index
    # plot the image(s)
    plt.subplot(1, 3, i+1)
    plt.imshow(img/255.)
    if class_name == pred_class: # colour the title green if the prediction is correct, red if not
        title_color = 'g'
    else:
        title_color = 'r'
    plt.title(f'actual: {class_name}, pred: {pred_class}, prob: {pred_prob.max():.2f}', c=title_color)
    plt.axis(False);
1/1 [==============================] - 0s 52ms/step 1/1 [==============================] - 0s 67ms/step 1/1 [==============================] - 0s 51ms/step
Going through multiple re-runs, you can clearly see how the model makes wrong predictions on foods that look eerily similar to other dishes.
Finding the most wrong predictions¶
It's a good idea to go through 100+ random instances of the model's predictions to get a feel for how it's doing.
You may notice that the model sometimes has high confidence in a prediction, yet the predicted class turns out to be wrong.
These most wrong predictions can give further insight into the model's performance.
So why not write code to collect all predictions where the prediction probability was very high (0.95+) but the predicted class was wrong?
We'll go through these steps:
- Get all of the image file paths in the test dataset using the list_files() method.
- Create a pandas DataFrame of the image filepaths, true labels, predicted classes, max prediction probabilities, true label names and predicted class names.
- Note: We don't necessarily have to create a DataFrame like this, but it helps with visualization.
- Use our DataFrame to find all the wrong predictions (where the true label doesn't match the predicted class).
- Sort the DataFrame by wrong predictions with the highest prediction probabilities.
- Visualize the images with the highest prediction probability but a wrong prediction.
# 1. get the filenames of all of the test data
filepaths = []
for filepath in test_data.list_files('101_food_classes_10_percent/test/*/*.jpg',
                                     shuffle=False):
    filepaths.append(filepath.numpy())
filepaths[:10]
[b'101_food_classes_10_percent\\test\\apple_pie\\1011328.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\101251.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\1034399.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\103801.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\1038694.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\1047447.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\1068632.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\110043.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\1106961.jpg', b'101_food_classes_10_percent\\test\\apple_pie\\1113017.jpg']
Now that we have all the test image file paths, let's combine them into a DataFrame along with the:
- true class index (y_labels)
- predicted class index (pred_classes)
- max prediction probability (pred_probs.max(axis=1))
- true class name
- predicted class name
# 2. create a dataframe out of current prediction data for analysis
import pandas as pd
pred_df = pd.DataFrame({'img_path': filepaths,
                        'y_true': y_labels,
                        'y_pred': pred_classes,
                        'pred_conf': pred_probs.max(axis=1),
                        'y_true_classname': [class_names[i] for i in y_labels],
                        'y_pred_classname': [class_names[i] for i in pred_classes]})
pred_df.head()
| img_path | y_true | y_pred | pred_conf | y_true_classname | y_pred_classname | |
|---|---|---|---|---|---|---|
| 0 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 52 | 0.847418 | apple_pie | gyoza |
| 1 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 0 | 0.964017 | apple_pie | apple_pie |
| 2 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 0 | 0.959259 | apple_pie | apple_pie |
| 3 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 80 | 0.658607 | apple_pie | pulled_pork_sandwich |
| 4 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 79 | 0.367903 | apple_pie | prime_rib |
# 3. use our dataframe to find all wrong predictions
pred_df['pred_correct'] = pred_df['y_true'] == pred_df['y_pred']
pred_df.head()
| img_path | y_true | y_pred | pred_conf | y_true_classname | y_pred_classname | pred_correct | |
|---|---|---|---|---|---|---|---|
| 0 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 52 | 0.847418 | apple_pie | gyoza | False |
| 1 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 0 | 0.964017 | apple_pie | apple_pie | True |
| 2 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 0 | 0.959259 | apple_pie | apple_pie | True |
| 3 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 80 | 0.658607 | apple_pie | pulled_pork_sandwich | False |
| 4 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 79 | 0.367903 | apple_pie | prime_rib | False |
# 4. sort the wrong predictions by highest prediction probability
top_100_wrong = pred_df[pred_df['pred_correct'] == False].sort_values('pred_conf', ascending=False)[:100]
top_100_wrong.head(20)
| img_path | y_true | y_pred | pred_conf | y_true_classname | y_pred_classname | pred_correct | |
|---|---|---|---|---|---|---|---|
| 21810 | b'101_food_classes_10_percent\\test\\scallops\... | 87 | 29 | 0.999997 | scallops | cup_cakes | False |
| 231 | b'101_food_classes_10_percent\\test\\apple_pie... | 0 | 100 | 0.999995 | apple_pie | waffles | False |
| 15359 | b'101_food_classes_10_percent\\test\\lobster_r... | 61 | 53 | 0.999988 | lobster_roll_sandwich | hamburger | False |
| 23539 | b'101_food_classes_10_percent\\test\\strawberr... | 94 | 83 | 0.999987 | strawberry_shortcake | red_velvet_cake | False |
| 21400 | b'101_food_classes_10_percent\\test\\samosa\\3... | 85 | 92 | 0.999981 | samosa | spring_rolls | False |
| 24540 | b'101_food_classes_10_percent\\test\\tiramisu\... | 98 | 83 | 0.999947 | tiramisu | red_velvet_cake | False |
| 2511 | b'101_food_classes_10_percent\\test\\bruschett... | 10 | 61 | 0.999945 | bruschetta | lobster_roll_sandwich | False |
| 5574 | b'101_food_classes_10_percent\\test\\chocolate... | 22 | 21 | 0.999939 | chocolate_mousse | chocolate_cake | False |
| 17855 | b'101_food_classes_10_percent\\test\\paella\\2... | 71 | 65 | 0.999931 | paella | mussels | False |
| 23797 | b'101_food_classes_10_percent\\test\\sushi\\16... | 95 | 86 | 0.999904 | sushi | sashimi | False |
| 18001 | b'101_food_classes_10_percent\\test\\pancakes\... | 72 | 67 | 0.999903 | pancakes | omelette | False |
| 11642 | b'101_food_classes_10_percent\\test\\garlic_br... | 46 | 10 | 0.999877 | garlic_bread | bruschetta | False |
| 10847 | b'101_food_classes_10_percent\\test\\fried_cal... | 43 | 68 | 0.999872 | fried_calamari | onion_rings | False |
| 23631 | b'101_food_classes_10_percent\\test\\strawberr... | 94 | 83 | 0.999858 | strawberry_shortcake | red_velvet_cake | False |
| 1155 | b'101_food_classes_10_percent\\test\\beef_tart... | 4 | 5 | 0.999858 | beef_tartare | beet_salad | False |
| 10854 | b'101_food_classes_10_percent\\test\\fried_cal... | 43 | 68 | 0.999854 | fried_calamari | onion_rings | False |
| 23904 | b'101_food_classes_10_percent\\test\\sushi\\33... | 95 | 86 | 0.999823 | sushi | sashimi | False |
| 7316 | b'101_food_classes_10_percent\\test\\cup_cakes... | 29 | 83 | 0.999817 | cup_cakes | red_velvet_cake | False |
| 13144 | b'101_food_classes_10_percent\\test\\gyoza\\31... | 52 | 92 | 0.999799 | gyoza | spring_rolls | False |
| 10880 | b'101_food_classes_10_percent\\test\\fried_cal... | 43 | 68 | 0.999778 | fried_calamari | onion_rings | False |
# 5. visualize some of the most wrong examples
images_to_view = 9
start_index = 10 # change the start index to view more
plt.figure(figsize=(15,10))
for i, row in enumerate(top_100_wrong[start_index:start_index+images_to_view].itertuples()):
    plt.subplot(3, 3, i+1)
    img = load_and_prep_image(row[1], scale=True)
    _, _, _, _, pred_prob, y_true, y_pred, _ = row # unpack only the fields we need from each row
    plt.imshow(img)
    plt.title(f'actual: {y_true}, pred: {y_pred} \nprob: {pred_prob:.2f}')
    plt.axis(False)
Looking at the most wrong predictions, we can note a few things:
- Some of the labels might be wrong - If our model ends up being good enough, it may predict so well that it gets the right label for a class that was mislabeled in our testing dataset. In that case, we could use the model to help us improve the labelling of our data, in turn making future models better. This is called
active learning.
Looking at the top left image, the model predicted omelette, but the true class is supposed to be pancakes. Just at a glance, we notice a very yellow object that looks roundish and flat in shape. Even though pancakes do involve eggs, they're incorporated into a batter of flour and milk, not served as a separate side dish.
- More samples needed - If a certain class is consistently being classified as another class, it's probably a good idea to gather more samples of both classes in different scenarios for the model to improve on.
Looking at the top right and middle right images, the model predicted calamari and onion rings the other way around. Seeing both images side by side shows how similar these food classes are to each other, so it would be a good idea to find more examples to help the model differentiate them.
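To check whether a confusion like this is systematic rather than a one-off, we can count the most frequent (true, predicted) class pairs among the wrong predictions. Below is a minimal sketch, assuming a predictions DataFrame with the y_true_classname, y_pred_classname and pred_correct columns shown in the table above (the function name and the demo DataFrame are illustrative, not from the notebook):

```python
import pandas as pd

def top_confusions(pred_df: pd.DataFrame, n: int = 5) -> pd.DataFrame:
    """Count the most frequent (true, predicted) class pairs among wrong predictions."""
    wrong = pred_df[~pred_df["pred_correct"]] # keep only the misclassified rows
    return (wrong.groupby(["y_true_classname", "y_pred_classname"])
                 .size()
                 .reset_index(name="count")
                 .sort_values("count", ascending=False)
                 .head(n))

# tiny illustrative example (stand-in for the real predictions DataFrame)
demo = pd.DataFrame({
    "y_true_classname": ["fried_calamari", "fried_calamari", "sushi", "gyoza"],
    "y_pred_classname": ["onion_rings", "onion_rings", "sashimi", "spring_rolls"],
    "pred_correct": [False, False, False, False],
})
print(top_confusions(demo, n=2))
```

Pairs with a high count (like fried_calamari predicted as onion_rings) are the ones worth collecting more training samples for.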
Test out the big dog model on test images, as well as custom images of food¶
We've visualized some of the model's predictions on the test dataset. Now it's time to use our model to predict on our own custom images of food.
Let's download and unzip a prepared folder of third-party food images.
!curl -O https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip
unzip_data('custom_food_images.zip')
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0
0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0
2 12.5M 2 281k 0 0 241k 0 0:00:53 0:00:01 0:00:52 242k
70 12.5M 70 9118k 0 0 4211k 0 0:00:03 0:00:02 0:00:01 4215k
100 12.5M 100 12.5M 0 0 5063k 0 0:00:02 0:00:02 --:--:-- 5068k
We can load these in and turn them into tensors with the load_and_prep_image() function. But first, we need a list of image filepaths.
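The load_and_prep_image() helper isn't shown in this notebook. As a reminder, here's a minimal sketch of what such a helper might look like, assuming TensorFlow image ops and a 224x224 model input size (both assumptions, check your own helper's definition):

```python
import tensorflow as tf

def load_and_prep_image(filename, img_shape=224, scale=True):
    """Read an image file into a float tensor of shape (img_shape, img_shape, 3).

    If scale=True, pixel values are rescaled to [0, 1]; otherwise they stay
    in [0, 255] (useful when the model rescales inputs internally).
    """
    img = tf.io.read_file(filename)                     # read raw bytes from disk
    img = tf.io.decode_image(img, channels=3)           # decode JPEG/PNG into a uint8 tensor
    img = tf.image.resize(img, [img_shape, img_shape])  # resize to the model's input size (float32)
    return img / 255. if scale else img
```

Note the scale flag: feature extraction models built on EfficientNet-style backbones often include their own rescaling layer, which is why the prediction code below passes scale=False.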
# get custom food image filepaths
custom_food_images = ['custom_food_images/' + img_path for img_path in os.listdir('custom_food_images')]
custom_food_images
['custom_food_images/hamburger.jpeg', 'custom_food_images/steak.jpeg', 'custom_food_images/sushi.jpeg', 'custom_food_images/chicken_wings.jpeg', 'custom_food_images/ramen.jpeg', 'custom_food_images/pizza-dad.jpeg']
We can now use similar code to what was used previously to load in our images, make a prediction on each using the model, and then plot the image alongside its prediction.
# make predictions on custom food images
for img in custom_food_images:
    img = load_and_prep_image(img, scale=False) # load in target image and turn it into a tensor
    pred_prob = model.predict(tf.expand_dims(img, axis=0)) # add a batch dimension (the model expects a batch of images)
    pred_class = class_names[pred_prob.argmax()] # find the predicted class label
    # plot the image
    plt.figure()
    plt.imshow(img/255.) # imshow expects float inputs in the [0, 1] range
    plt.title(f'pred: {pred_class}, prob: {pred_prob.max():.2f}')
    plt.axis(False)
1/1 [==============================] - 0s 51ms/step 1/1 [==============================] - 0s 52ms/step 1/1 [==============================] - 0s 52ms/step 1/1 [==============================] - 0s 57ms/step 1/1 [==============================] - 0s 52ms/step 1/1 [==============================] - 0s 51ms/step
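Calling model.predict() once per image issues a separate forward pass each time (hence the six 1/1 progress bars above). For more than a handful of images, stacking them into a single batch and predicting once is faster. A minimal sketch, assuming the images have already been preprocessed to the same shape (the function name is illustrative):

```python
import tensorflow as tf

def predict_batch(model, images):
    """Stack same-shaped image tensors into one batch and predict in a single call."""
    batch = tf.stack(images, axis=0)          # (n, H, W, 3)
    probs = model.predict(batch, verbose=0)   # (n, num_classes)
    return probs.argmax(axis=-1), probs.max(axis=-1)
```

The returned arrays line up with the input list, so the i-th class index and probability belong to images[i].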